Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from google.cloud.firestore_v1 import (
async_aggregation,
async_document,
async_query,
async_vector_query,
transaction,
Expand All @@ -31,9 +30,9 @@
BaseCollectionReference,
_item_to_document_ref,
)
from google.cloud.firestore_v1.document import DocumentReference

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_profile import ExplainOptions
Expand Down Expand Up @@ -140,9 +139,7 @@ async def add(
write_result = await document_ref.create(document_data, **kwargs)
return write_result.update_time, document_ref

def document(
self, document_id: str | None = None
) -> async_document.AsyncDocumentReference:
def document(self, document_id: str | None = None) -> AsyncDocumentReference:
"""Create a sub-document underneath the current collection.

Args:
Expand All @@ -162,7 +159,7 @@ async def list_documents(
page_size: int | None = None,
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> AsyncGenerator[DocumentReference, None]:
) -> AsyncGenerator[AsyncDocumentReference, None]:
"""List all subdocuments of the current collection.

Args:
Expand Down
27 changes: 16 additions & 11 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,40 +80,45 @@ def __init__(self, alias: str | None = None):
def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
if self.alias:
aggregation_pb.alias = self.alias
aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
return aggregation_pb


class SumAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
# convert field path to string if needed
field_str = (
field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
)
self.field_ref: str = field_str
super(SumAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
if self.alias:
aggregation_pb.alias = self.alias
aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
aggregation_pb.sum.field.field_path = self.field_ref
return aggregation_pb


class AvgAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
# convert field path to string if needed
field_str = (
field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
)
self.field_ref: str = field_str
super(AvgAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
if self.alias:
aggregation_pb.alias = self.alias
aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
aggregation_pb.avg.field.field_path = self.field_ref
return aggregation_pb
Expand Down
16 changes: 11 additions & 5 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
BaseVectorQuery,
DistanceMeasure,
)
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.query_profile import ExplainOptions
Expand Down Expand Up @@ -129,7 +130,7 @@ def _aggregation_query(self) -> BaseAggregationQuery:
def _vector_query(self) -> BaseVectorQuery:
raise NotImplementedError

def document(self, document_id: Optional[str] = None) -> DocumentReference:
def document(self, document_id: Optional[str] = None):
"""Create a sub-document underneath the current collection.

Args:
Expand All @@ -139,7 +140,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference:
uppercase and lowercase and letters.

Returns:
:class:`~google.cloud.firestore_v1.document.DocumentReference`:
:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`:
The child document.
"""
if document_id is None:
Expand Down Expand Up @@ -179,7 +180,7 @@ def _prep_add(
document_id: Optional[str] = None,
retry: retries.Retry | retries.AsyncRetry | object | None = None,
timeout: Optional[float] = None,
) -> Tuple[DocumentReference, dict]:
):
"""Shared setup for async / sync :method:`add`"""
if document_id is None:
document_id = _auto_id()
Expand Down Expand Up @@ -226,7 +227,8 @@ def list_documents(
retry: retries.Retry | retries.AsyncRetry | object | None = None,
timeout: Optional[float] = None,
) -> Union[
Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any]
Generator[DocumentReference, Any, Any],
AsyncGenerator[AsyncDocumentReference, Any],
]:
raise NotImplementedError

Expand Down Expand Up @@ -602,13 +604,17 @@ def _auto_id() -> str:
return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20))


def _item_to_document_ref(collection_reference, item) -> DocumentReference:
def _item_to_document_ref(collection_reference, item):
"""Convert Document resource to document ref.

Args:
collection_reference (google.api_core.page_iterator.GRPCIterator):
iterator response
item (dict): document resource

Returns:
:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`:
The child document
"""
document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1]
return collection_reference.document(document_id)
10 changes: 5 additions & 5 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _validate_opation(op_string, value):
class FieldFilter(BaseFilter):
"""Class representation of a Field Filter."""

def __init__(self, field_path, op_string, value=None):
def __init__(self, field_path: str, op_string: str, value: Any | None = None):
self.field_path = field_path
self.value = value
self.op_string = _validate_opation(op_string, value)
Expand All @@ -205,8 +205,8 @@ class BaseCompositeFilter(BaseFilter):

def __init__(
self,
operator=StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
filters=None,
operator: int = StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
filters: list[BaseFilter] | None = None,
):
self.operator = operator
if filters is None:
Expand Down Expand Up @@ -238,7 +238,7 @@ def _to_pb(self):
class Or(BaseCompositeFilter):
"""Class representation of an OR Filter."""

def __init__(self, filters):
def __init__(self, filters: list[BaseFilter]):
super().__init__(
operator=StructuredQuery.CompositeFilter.Operator.OR, filters=filters
)
Expand All @@ -247,7 +247,7 @@ def __init__(self, filters):
class And(BaseCompositeFilter):
"""Class representation of an AND Filter."""

def __init__(self, filters):
def __init__(self, filters: list[BaseFilter]):
super().__init__(
operator=StructuredQuery.CompositeFilter.Operator.AND, filters=filters
)
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/bulk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def wrapper(self, *args, **kwargs):
# For code parity, even `SendMode.serial` scenarios should return
# a future here. Anything else would badly complicate calling code.
result = fn(self, *args, **kwargs)
future = concurrent.futures.Future()
future: concurrent.futures.Future = concurrent.futures.Future()
future.set_result(result)
return future

Expand Down Expand Up @@ -319,6 +319,7 @@ def __init__(
self._total_batches_sent: int = 0
self._total_write_operations: int = 0

self._executor: concurrent.futures.ThreadPoolExecutor
self._ensure_executor()

@staticmethod
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/firestore_v1/field_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ class FieldPath(object):
Indicating path of the key to be used.
"""

def __init__(self, *parts):
def __init__(self, *parts: str):
for part in parts:
if not isinstance(part, str) or not part:
error = "One or more components is not a string or is empty."
raise ValueError(error)
self.parts = tuple(parts)

@classmethod
def from_api_repr(cls, api_repr: str):
def from_api_repr(cls, api_repr: str) -> "FieldPath":
"""Factory: create a FieldPath from the string formatted per the API.

Args:
Expand All @@ -288,7 +288,7 @@ def from_api_repr(cls, api_repr: str):
return cls(*parse_field_path(api_repr))

@classmethod
def from_string(cls, path_string: str):
def from_string(cls, path_string: str) -> "FieldPath":
"""Factory: create a FieldPath from a unicode string representation.

This method splits on the character `.` and disallows the
Expand Down Expand Up @@ -351,7 +351,7 @@ def __add__(self, other):
else:
return NotImplemented

def to_api_repr(self):
def to_api_repr(self) -> str:
"""Render a quoted string representation of the FieldPath

Returns:
Expand All @@ -360,7 +360,7 @@ def to_api_repr(self):
"""
return render_field_path(self.parts)

def eq_or_parent(self, other):
def eq_or_parent(self, other) -> bool:
"""Check whether ``other`` is an ancestor.

Returns:
Expand All @@ -369,7 +369,7 @@ def eq_or_parent(self, other):
"""
return self.parts[: len(other.parts)] == other.parts[: len(self.parts)]

def lineage(self):
def lineage(self) -> set["FieldPath"]:
"""Return field paths for all parents.

Returns: Set[:class:`FieldPath`]
Expand All @@ -378,7 +378,7 @@ def lineage(self):
return {FieldPath(*self.parts[:index]) for index in indexes}

@staticmethod
def document_id():
def document_id() -> str:
"""A special FieldPath value to refer to the ID of a document. It can be used
in queries to sort or filter by the document ID.

Expand Down
37 changes: 21 additions & 16 deletions google/cloud/firestore_v1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
import functools
Expand Down Expand Up @@ -232,7 +233,7 @@ def __init__(
def _init_stream(self):
rpc_request = self._get_rpc_request

self._rpc = ResumableBidiRpc(
self._rpc: ResumableBidiRpc | None = ResumableBidiRpc(
start_rpc=self._api._transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
Expand All @@ -243,7 +244,9 @@ def _init_stream(self):
self._rpc.add_done_callback(self._on_rpc_done)

# The server assigns and updates the resume token.
self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot)
self._consumer: BackgroundConsumer | None = BackgroundConsumer(
self._rpc, self.on_snapshot
)
self._consumer.start()

@classmethod
Expand Down Expand Up @@ -330,16 +333,18 @@ def close(self, reason=None):
return

# Stop consuming messages.
if self.is_active:
_LOGGER.debug("Stopping consumer.")
self._consumer.stop()
self._consumer._on_response = None
if self._consumer:
if self.is_active:
_LOGGER.debug("Stopping consumer.")
self._consumer.stop()
self._consumer._on_response = None
self._consumer = None

self._snapshot_callback = None
self._rpc.close()
self._rpc._initial_request = None
self._rpc._callbacks = []
if self._rpc:
self._rpc.close()
self._rpc._initial_request = None
self._rpc._callbacks = []
self._rpc = None
self._closed = True
_LOGGER.debug("Finished stopping manager.")
Expand Down Expand Up @@ -460,13 +465,13 @@ def on_snapshot(self, proto):
message = f"Unknown target change type: {target_change_type}"
_LOGGER.info(f"on_snapshot: {message}")
self.close(reason=ValueError(message))

try:
# Use 'proto' vs 'pb' for datetime handling
meth(self, proto.target_change)
except Exception as exc2:
_LOGGER.debug(f"meth(proto) exc: {exc2}")
raise
else:
try:
# Use 'proto' vs 'pb' for datetime handling
meth(self, proto.target_change)
except Exception as exc2:
_LOGGER.debug(f"meth(proto) exc: {exc2}")
raise

# NOTE:
# in other implementations, such as node, the backoff is reset here
Expand Down
13 changes: 10 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,16 @@ def pytype(session):
def mypy(session):
"""Verify type hints are mypy compatible."""
session.install("-e", ".")
session.install("mypy", "types-setuptools")
# TODO: also verify types on tests, all of google package
session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental")
session.install("mypy", "types-setuptools", "types-protobuf")
session.run(
"mypy",
"-p",
"google.cloud.firestore_v1",
"--no-incremental",
"--check-untyped-defs",
"--exclude",
"services",
)


@nox.session(python=DEFAULT_PYTHON_VERSION)
Expand Down
Loading
Loading