Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.cloud.firestore_v1.stream_generator import (
StreamGenerator,
)
from google.cloud.firestore_v1.pipeline_source import PipelineSource

import datetime

Expand Down Expand Up @@ -356,19 +357,20 @@ def stream(
A generator of the query results.
"""

def pipeline(self):
def _build_pipeline(self, source: "PipelineSource"):
"""
Convert this query into a Pipeline

Queries containing a `cursor` or `limit_to_last` are not currently supported

Args:
source: the PipelineSource to build the pipeline off of
Raises:
- ValueError: raised if Query wasn't created with an associated client
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
Returns:
a Pipeline representing the query
"""
# use autoindexer to keep track of which field number to use for un-aliased fields
autoindexer = itertools.count(start=1)
exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations]
return self._nested_query.pipeline().aggregate(*exprs)
return self._nested_query._build_pipeline(source).aggregate(*exprs)
7 changes: 5 additions & 2 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.pipeline_source import PipelineSource
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.query_results import QueryResultsList
from google.cloud.firestore_v1.stream_generator import StreamGenerator
Expand Down Expand Up @@ -602,18 +603,20 @@ def find_nearest(
distance_threshold=distance_threshold,
)

def pipeline(self):
def _build_pipeline(self, source: "PipelineSource"):
"""
Convert this query into a Pipeline

Queries containing a `cursor` or `limit_to_last` are not currently supported

Args:
source: the PipelineSource to build the pipeline off of
Raises:
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
Returns:
a Pipeline representing the query
"""
return self._query().pipeline()
return self._query()._build_pipeline(source)


def _auto_id() -> str:
Expand Down
12 changes: 6 additions & 6 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.query_results import QueryResultsList
from google.cloud.firestore_v1.stream_generator import StreamGenerator
from google.cloud.firestore_v1.pipeline_source import PipelineSource

import datetime

Expand Down Expand Up @@ -1129,24 +1130,23 @@ def recursive(self: QueryType) -> QueryType:

return copied

def pipeline(self):
def _build_pipeline(self, source: "PipelineSource"):
"""
Convert this query into a Pipeline

Queries containing a `cursor` or `limit_to_last` are not currently supported

Args:
source: the PipelineSource to build the pipeline off of
Raises:
- ValueError: raised if Query wasn't created with an associated client
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
Returns:
a Pipeline representing the query
"""
if not self._client:
raise ValueError("Query does not have an associated client")
if self._all_descendants:
ppl = self._client.pipeline().collection_group(self._parent.id)
ppl = source.collection_group(self._parent.id)
else:
ppl = self._client.pipeline().collection(self._parent._path)
ppl = source.collection(self._parent._path)

# Filters
for filter_ in self._field_filters:
Expand Down
20 changes: 20 additions & 0 deletions google/cloud/firestore_v1/pipeline_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from google.cloud.firestore_v1.client import Client
from google.cloud.firestore_v1.async_client import AsyncClient
from google.cloud.firestore_v1.base_document import BaseDocumentReference
from google.cloud.firestore_v1.base_query import BaseQuery
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
from google.cloud.firestore_v1.base_collection import BaseCollectionReference


PipelineType = TypeVar("PipelineType", bound=_BasePipeline)
Expand All @@ -43,6 +46,23 @@ def __init__(self, client: Client | AsyncClient):
def _create_pipeline(self, source_stage):
return self.client._pipeline_cls._create_with_stages(self.client, source_stage)

def create_from(
self, query: "BaseQuery" | "BaseAggregationQuery" | "BaseCollectionReference"
) -> PipelineType:
"""
Create a pipeline from an existing query

Queries containing a `cursor` or `limit_to_last` are not currently supported

Args:
query: the query to build the pipeline off of
Raises:
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
Returns:
a new pipeline instance representing the query
"""
return query._build_pipeline(self)

def collection(self, path: str | tuple[str]) -> PipelineType:
"""
Creates a new Pipeline that operates on a specified Firestore collection.
Expand Down
3 changes: 2 additions & 1 deletion tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def _clean_results(results):
except Exception as e:
# if we expect the query to fail, capture the exception
query_exception = e
pipeline = query.pipeline()
client = query._client
pipeline = client.pipeline().create_from(query)
if query_exception:
# ensure that the pipeline uses same error as query
with pytest.raises(query_exception.__class__):
Expand Down
3 changes: 2 additions & 1 deletion tests/system/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def _clean_results(results):
except Exception as e:
# if we expect the query to fail, capture the exception
query_exception = e
pipeline = query.pipeline()
client = query._client
pipeline = client.pipeline().create_from(query)
if query_exception:
# ensure that the pipeline uses same error as query
with pytest.raises(query_exception.__class__):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/v1/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias):
query = make_query(parent)
aggregation_query = make_aggregation_query(query)
aggregation_query.sum(field, alias=in_alias)
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, Pipeline)
assert len(pipeline.stages) == 2
assert isinstance(pipeline.stages[0], Collection)
Expand Down Expand Up @@ -1071,7 +1071,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias):
query = make_query(parent)
aggregation_query = make_aggregation_query(query)
aggregation_query.avg(field, alias=in_alias)
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, Pipeline)
assert len(pipeline.stages) == 2
assert isinstance(pipeline.stages[0], Collection)
Expand Down Expand Up @@ -1102,7 +1102,7 @@ def test_aggreation_to_pipeline_count(in_alias, out_alias):
query = make_query(parent)
aggregation_query = make_aggregation_query(query)
aggregation_query.count(alias=in_alias)
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, Pipeline)
assert len(pipeline.stages) == 2
assert isinstance(pipeline.stages[0], Collection)
Expand All @@ -1127,7 +1127,7 @@ def test_aggreation_to_pipeline_count_increment():
aggregation_query = make_aggregation_query(query)
for _ in range(n):
aggregation_query.count()
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
aggregate_stage = pipeline.stages[1]
assert len(aggregate_stage.accumulators) == n
for i in range(n):
Expand All @@ -1146,7 +1146,7 @@ def test_aggreation_to_pipeline_complex():
aggregation_query.count()
aggregation_query.avg("other")
aggregation_query.sum("final")
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, Pipeline)
assert len(pipeline.stages) == 3
assert isinstance(pipeline.stages[0], Collection)
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/v1/test_async_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias):
query = make_async_query(parent)
aggregation_query = make_async_aggregation_query(query)
aggregation_query.sum(field, alias=in_alias)
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, AsyncPipeline)
assert len(pipeline.stages) == 2
assert isinstance(pipeline.stages[0], Collection)
Expand Down Expand Up @@ -747,7 +747,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias):
query = make_async_query(parent)
aggregation_query = make_async_aggregation_query(query)
aggregation_query.avg(field, alias=in_alias)
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, AsyncPipeline)
assert len(pipeline.stages) == 2
assert isinstance(pipeline.stages[0], Collection)
Expand Down Expand Up @@ -778,7 +778,7 @@ def test_async_aggreation_to_pipeline_count(in_alias, out_alias):
query = make_async_query(parent)
aggregation_query = make_async_aggregation_query(query)
aggregation_query.count(alias=in_alias)
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, AsyncPipeline)
assert len(pipeline.stages) == 2
assert isinstance(pipeline.stages[0], Collection)
Expand All @@ -803,7 +803,7 @@ def test_aggreation_to_pipeline_count_increment():
aggregation_query = make_async_aggregation_query(query)
for _ in range(n):
aggregation_query.count()
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
aggregate_stage = pipeline.stages[1]
assert len(aggregate_stage.accumulators) == n
for i in range(n):
Expand All @@ -822,7 +822,7 @@ def test_async_aggreation_to_pipeline_complex():
aggregation_query.count()
aggregation_query.avg("other")
aggregation_query.sum("final")
pipeline = aggregation_query.pipeline()
pipeline = aggregation_query._build_pipeline(client.pipeline())
assert isinstance(pipeline, AsyncPipeline)
assert len(pipeline.stages) == 3
assert isinstance(pipeline.stages[0], Collection)
Expand Down
8 changes: 1 addition & 7 deletions tests/unit/v1/test_async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,15 +609,9 @@ def test_asynccollectionreference_pipeline():

client = make_async_client()
collection = _make_async_collection_reference("collection", client=client)
pipeline = collection.pipeline()
pipeline = collection._build_pipeline(client.pipeline())
assert isinstance(pipeline, AsyncPipeline)
# should have single "Collection" stage
assert len(pipeline.stages) == 1
assert isinstance(pipeline.stages[0], Collection)
assert pipeline.stages[0].path == "/collection"


def test_asynccollectionreference_pipeline_no_client():
collection = _make_async_collection_reference("collection")
with pytest.raises(ValueError, match="client"):
collection.pipeline()
4 changes: 2 additions & 2 deletions tests/unit/v1/test_async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def test_asyncquery_collection_pipeline_type():
client = make_async_client()
parent = client.collection("test")
query = parent._query()
ppl = query.pipeline()
ppl = query._build_pipeline(client.pipeline())
assert isinstance(ppl, AsyncPipeline)


Expand All @@ -926,5 +926,5 @@ def test_asyncquery_collectiongroup_pipeline_type():

client = make_async_client()
query = client.collection_group("test")
ppl = query.pipeline()
ppl = query._build_pipeline(client.pipeline())
assert isinstance(ppl, AsyncPipeline)
7 changes: 4 additions & 3 deletions tests/unit/v1/test_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,11 @@ def test_basecollectionreference_pipeline(mock_query):
_query.return_value = mock_query

collection = _make_base_collection_reference("collection")
pipeline = collection.pipeline()
mock_source = mock.Mock()
pipeline = collection._build_pipeline(mock_source)

mock_query.pipeline.assert_called_once_with()
assert pipeline == mock_query.pipeline.return_value
mock_query._build_pipeline.assert_called_once_with(mock_source)
assert pipeline == mock_query._build_pipeline.return_value


@mock.patch("random.choice")
Expand Down
Loading