
Commit 7f96290

chore: enable mypy testing (#1057)
1 parent d8e3af1 commit 7f96290

File tree: 12 files changed (+115, -59 lines)

google/cloud/firestore_v1/async_collection.py

Lines changed: 3 additions & 7 deletions
@@ -22,7 +22,6 @@
 
 from google.cloud.firestore_v1 import (
     async_aggregation,
-    async_document,
     async_query,
     async_vector_query,
     transaction,
@@ -31,11 +30,10 @@
     BaseCollectionReference,
     _item_to_document_ref,
 )
-from google.cloud.firestore_v1.document import DocumentReference
 
 if TYPE_CHECKING:  # pragma: NO COVER
     import datetime
-
+    from google.cloud.firestore_v1.async_document import AsyncDocumentReference
     from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
     from google.cloud.firestore_v1.base_document import DocumentSnapshot
     from google.cloud.firestore_v1.query_profile import ExplainOptions
@@ -142,9 +140,7 @@ async def add(
         write_result = await document_ref.create(document_data, **kwargs)
         return write_result.update_time, document_ref
 
-    def document(
-        self, document_id: str | None = None
-    ) -> async_document.AsyncDocumentReference:
+    def document(self, document_id: str | None = None) -> AsyncDocumentReference:
         """Create a sub-document underneath the current collection.
 
         Args:
@@ -166,7 +162,7 @@ async def list_documents(
         timeout: float | None = None,
         *,
         read_time: datetime.datetime | None = None,
-    ) -> AsyncGenerator[DocumentReference, None]:
+    ) -> AsyncGenerator[AsyncDocumentReference, None]:
         """List all subdocuments of the current collection.
 
         Args:

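With the re-typed document() above, an async collection now advertises AsyncDocumentReference rather than the synchronous DocumentReference. A minimal usage sketch, assuming default credentials and an illustrative "users" collection that are not part of this commit:

    import asyncio
    from google.cloud import firestore

    async def main():
        client = firestore.AsyncClient()       # default project and credentials
        users = client.collection("users")     # "users" is an illustrative name
        ada = users.document("alovelace")      # now typed as AsyncDocumentReference
        await ada.set({"first": "Ada", "last": "Lovelace"})

    asyncio.run(main())
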
google/cloud/firestore_v1/async_query.py

Lines changed: 12 additions & 3 deletions
@@ -20,7 +20,16 @@
 """
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    List,
+    Optional,
+    Type,
+    Union,
+    Sequence,
+)
 
 from google.api_core import gapic_v1
 from google.api_core import retry_async as retries
@@ -256,7 +265,7 @@ async def get(
     def find_nearest(
         self,
         vector_field: str,
-        query_vector: Vector,
+        query_vector: Union[Vector, Sequence[float]],
         limit: int,
         distance_measure: DistanceMeasure,
         *,
@@ -269,7 +278,7 @@ def find_nearest(
         Args:
             vector_field (str): An indexed vector field to search upon. Only documents which contain
                 vectors whose dimensionality match the query_vector can be returned.
-            query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
+            query_vector (Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
                 than 2048 dimensions.
            limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
            distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.

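The widened query_vector annotation means callers can pass either a Vector or a plain sequence of floats. A small sketch against a hypothetical "products" collection with an indexed "embedding" vector field; the synchronous client is shown, and the async API mirrors it:

    from google.cloud import firestore
    from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
    from google.cloud.firestore_v1.vector import Vector

    client = firestore.Client()
    products = client.collection("products")  # assumes an indexed "embedding" field

    # Both forms now type-check: a Vector or a bare sequence of floats.
    query = products.find_nearest(
        vector_field="embedding",
        query_vector=Vector([0.1, 0.2, 0.3]),   # or simply [0.1, 0.2, 0.3]
        limit=5,
        distance_measure=DistanceMeasure.EUCLIDEAN,
    )
    for snapshot in query.get():
        print(snapshot.id)
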
google/cloud/firestore_v1/base_aggregation.py

Lines changed: 16 additions & 11 deletions
@@ -83,40 +83,45 @@ def __init__(self, alias: str | None = None):
     def _to_protobuf(self):
         """Convert this instance to the protobuf representation"""
         aggregation_pb = StructuredAggregationQuery.Aggregation()
-        aggregation_pb.alias = self.alias
+        if self.alias:
+            aggregation_pb.alias = self.alias
         aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
         return aggregation_pb
 
 
 class SumAggregation(BaseAggregation):
     def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
-        if isinstance(field_ref, FieldPath):
-            # convert field path to string
-            field_ref = field_ref.to_api_repr()
-        self.field_ref = field_ref
+        # convert field path to string if needed
+        field_str = (
+            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
+        )
+        self.field_ref: str = field_str
         super(SumAggregation, self).__init__(alias=alias)
 
     def _to_protobuf(self):
         """Convert this instance to the protobuf representation"""
         aggregation_pb = StructuredAggregationQuery.Aggregation()
-        aggregation_pb.alias = self.alias
+        if self.alias:
+            aggregation_pb.alias = self.alias
         aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
         aggregation_pb.sum.field.field_path = self.field_ref
         return aggregation_pb
 
 
 class AvgAggregation(BaseAggregation):
     def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
-        if isinstance(field_ref, FieldPath):
-            # convert field path to string
-            field_ref = field_ref.to_api_repr()
-        self.field_ref = field_ref
+        # convert field path to string if needed
+        field_str = (
+            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
+        )
+        self.field_ref: str = field_str
         super(AvgAggregation, self).__init__(alias=alias)
 
     def _to_protobuf(self):
         """Convert this instance to the protobuf representation"""
         aggregation_pb = StructuredAggregationQuery.Aggregation()
-        aggregation_pb.alias = self.alias
+        if self.alias:
+            aggregation_pb.alias = self.alias
         aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
         aggregation_pb.avg.field.field_path = self.field_ref
         return aggregation_pb

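Two things change above: field_ref is normalized to its string form once in __init__, so the attribute is always a str for mypy, and alias is only copied into the proto when it is actually set, since the field is optional. A quick sketch of the normalization; the field names and alias are illustrative:

    from google.cloud.firestore_v1.base_aggregation import SumAggregation
    from google.cloud.firestore_v1.field_path import FieldPath

    # A FieldPath is converted to its API string representation up front...
    agg = SumAggregation(FieldPath("stats", "score"))
    print(agg.field_ref)   # -> "stats.score" (always a str now)

    # ...and a plain string is stored unchanged.
    agg2 = SumAggregation("stats.score", alias="total_score")
    print(agg2.alias)      # -> "total_score"; omitted from the proto when None
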
google/cloud/firestore_v1/base_client.py

Lines changed: 1 addition & 1 deletion
@@ -476,7 +476,7 @@ def _prep_collections(
         read_time: datetime.datetime | None = None,
     ) -> Tuple[dict, dict]:
         """Shared setup for async/sync :meth:`collections`."""
-        request = {
+        request: dict[str, Any] = {
             "parent": "{}/documents".format(self._database_string),
         }
         if read_time is not None:

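The explicit dict[str, Any] annotation is the kind of change mypy tends to force here: without it, the type of request would likely be inferred from the first literal value (a str), and adding the read_time entry a few lines later would be rejected. A standalone illustration; the keys are only examples:

    import datetime
    from typing import Any

    # Without the annotation mypy infers dict[str, str] from the first entry
    # and flags the datetime value added below.
    request: dict[str, Any] = {"parent": "projects/p/databases/d/documents"}
    request["read_time"] = datetime.datetime.now(tz=datetime.timezone.utc)
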
google/cloud/firestore_v1/base_collection.py

Lines changed: 11 additions & 5 deletions
@@ -45,6 +45,7 @@
     BaseVectorQuery,
     DistanceMeasure,
 )
+from google.cloud.firestore_v1.async_document import AsyncDocumentReference
 from google.cloud.firestore_v1.document import DocumentReference
 from google.cloud.firestore_v1.field_path import FieldPath
 from google.cloud.firestore_v1.query_profile import ExplainOptions
@@ -132,7 +133,7 @@ def _aggregation_query(self) -> BaseAggregationQuery:
     def _vector_query(self) -> BaseVectorQuery:
         raise NotImplementedError
 
-    def document(self, document_id: Optional[str] = None) -> DocumentReference:
+    def document(self, document_id: Optional[str] = None):
         """Create a sub-document underneath the current collection.
 
         Args:
@@ -142,7 +143,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference:
             uppercase and lowercase and letters.
 
         Returns:
-            :class:`~google.cloud.firestore_v1.document.DocumentReference`:
+            :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`:
                 The child document.
         """
         if document_id is None:
@@ -182,7 +183,7 @@ def _prep_add(
         document_id: Optional[str] = None,
         retry: retries.Retry | retries.AsyncRetry | object | None = None,
         timeout: Optional[float] = None,
-    ) -> Tuple[DocumentReference, dict]:
+    ):
         """Shared setup for async / sync :method:`add`"""
         if document_id is None:
             document_id = _auto_id()
@@ -234,7 +235,8 @@ def list_documents(
         *,
         read_time: Optional[datetime.datetime] = None,
     ) -> Union[
-        Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any]
+        Generator[DocumentReference, Any, Any],
+        AsyncGenerator[AsyncDocumentReference, Any],
     ]:
         raise NotImplementedError
 
@@ -612,13 +614,17 @@ def _auto_id() -> str:
     return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20))
 
 
-def _item_to_document_ref(collection_reference, item) -> DocumentReference:
+def _item_to_document_ref(collection_reference, item):
    """Convert Document resource to document ref.
 
    Args:
        collection_reference (google.api_core.page_iterator.GRPCIterator):
            iterator response
        item (dict): document resource
+
+    Returns:
+        :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`:
+            The child document
    """
    document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1]
    return collection_reference.document(document_id)

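The base class now leaves the concrete return type of document() to the sync and async subclasses, which narrow it, and documents that _item_to_document_ref hands back a BaseDocumentReference. For reference, a sketch of the auto-ID path when no document_id is passed; the collection name is illustrative and default credentials are assumed:

    from google.cloud import firestore

    client = firestore.Client()
    cities = client.collection("cities")

    # No ID supplied: a random 20-character ID is generated (see _auto_id above).
    new_ref = cities.document()
    print(len(new_ref.id))   # -> 20
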
google/cloud/firestore_v1/base_query.py

Lines changed: 5 additions & 5 deletions
@@ -182,7 +182,7 @@ def _validate_opation(op_string, value):
 class FieldFilter(BaseFilter):
     """Class representation of a Field Filter."""
 
-    def __init__(self, field_path, op_string, value=None):
+    def __init__(self, field_path: str, op_string: str, value: Any | None = None):
         self.field_path = field_path
         self.value = value
         self.op_string = _validate_opation(op_string, value)
@@ -208,8 +208,8 @@ class BaseCompositeFilter(BaseFilter):
 
     def __init__(
         self,
-        operator=StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
-        filters=None,
+        operator: int = StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
+        filters: list[BaseFilter] | None = None,
     ):
         self.operator = operator
         if filters is None:
@@ -241,7 +241,7 @@ def _to_pb(self):
 class Or(BaseCompositeFilter):
     """Class representation of an OR Filter."""
 
-    def __init__(self, filters):
+    def __init__(self, filters: list[BaseFilter]):
         super().__init__(
             operator=StructuredQuery.CompositeFilter.Operator.OR, filters=filters
         )
@@ -250,7 +250,7 @@ def __init__(self, filters):
 class And(BaseCompositeFilter):
     """Class representation of an AND Filter."""
 
-    def __init__(self, filters):
+    def __init__(self, filters: list[BaseFilter]):
         super().__init__(
             operator=StructuredQuery.CompositeFilter.Operator.AND, filters=filters
        )

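The filter classes gain annotations but no behavior change: FieldFilter takes a field path string, an operator string, and a value, while And and Or take a list of BaseFilter instances. A short composition sketch; the collection and field names are illustrative and default credentials are assumed:

    from google.cloud import firestore
    from google.cloud.firestore_v1.base_query import And, FieldFilter, Or

    client = firestore.Client()
    cities = client.collection("cities")

    # Composite filters wrap a list[BaseFilter], as the new annotations state.
    big_or_capital = Or([FieldFilter("population", ">", 1_000_000),
                         FieldFilter("capital", "==", True)])
    query = cities.where(filter=And([FieldFilter("state", "==", "CA"), big_or_capital]))
    for snap in query.stream():
        print(snap.id)
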
google/cloud/firestore_v1/bulk_writer.py

Lines changed: 2 additions & 1 deletion
@@ -110,7 +110,7 @@ def wrapper(self, *args, **kwargs):
         # For code parity, even `SendMode.serial` scenarios should return
         # a future here. Anything else would badly complicate calling code.
         result = fn(self, *args, **kwargs)
-        future = concurrent.futures.Future()
+        future: concurrent.futures.Future = concurrent.futures.Future()
         future.set_result(result)
         return future
 
@@ -319,6 +319,7 @@ def __init__(
         self._total_batches_sent: int = 0
         self._total_write_operations: int = 0
 
+        self._executor: concurrent.futures.ThreadPoolExecutor
         self._ensure_executor()
 
     @staticmethod

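Both bulk_writer changes are pure annotations for mypy: an empty Future() otherwise has no inferable result type, and _executor is declared up front so the later assignment in _ensure_executor type-checks. The Future case in isolation:

    import concurrent.futures

    # Declaring the variable's type lets mypy accept the immediately-set result.
    future: concurrent.futures.Future = concurrent.futures.Future()
    future.set_result("done")
    print(future.result())   # -> "done"
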
google/cloud/firestore_v1/field_path.py

Lines changed: 7 additions & 7 deletions
@@ -263,15 +263,15 @@ class FieldPath(object):
         Indicating path of the key to be used.
     """
 
-    def __init__(self, *parts):
+    def __init__(self, *parts: str):
         for part in parts:
             if not isinstance(part, str) or not part:
                 error = "One or more components is not a string or is empty."
                 raise ValueError(error)
         self.parts = tuple(parts)
 
     @classmethod
-    def from_api_repr(cls, api_repr: str):
+    def from_api_repr(cls, api_repr: str) -> "FieldPath":
         """Factory: create a FieldPath from the string formatted per the API.
 
         Args:
@@ -288,7 +288,7 @@ def from_api_repr(cls, api_repr: str):
         return cls(*parse_field_path(api_repr))
 
     @classmethod
-    def from_string(cls, path_string: str):
+    def from_string(cls, path_string: str) -> "FieldPath":
         """Factory: create a FieldPath from a unicode string representation.
 
         This method splits on the character `.` and disallows the
@@ -351,7 +351,7 @@ def __add__(self, other):
         else:
             return NotImplemented
 
-    def to_api_repr(self):
+    def to_api_repr(self) -> str:
         """Render a quoted string representation of the FieldPath
 
         Returns:
@@ -360,7 +360,7 @@ def to_api_repr(self):
         """
         return render_field_path(self.parts)
 
-    def eq_or_parent(self, other):
+    def eq_or_parent(self, other) -> bool:
         """Check whether ``other`` is an ancestor.
 
         Returns:
@@ -369,7 +369,7 @@ def eq_or_parent(self, other):
         return self.parts[: len(other.parts)] == other.parts[: len(self.parts)]
 
-    def lineage(self):
+    def lineage(self) -> set["FieldPath"]:
         """Return field paths for all parents.
 
         Returns: Set[:class:`FieldPath`]
@@ -378,7 +378,7 @@ def lineage(self):
         return {FieldPath(*self.parts[:index]) for index in indexes}
 
     @staticmethod
-    def document_id():
+    def document_id() -> str:
         """A special FieldPath value to refer to the ID of a document. It can be used
         in queries to sort or filter by the document ID.

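The return annotations added to FieldPath describe long-standing behavior: the factories return FieldPath instances and to_api_repr() returns the dotted (and, where needed, back-quoted) string form. A small round-trip sketch with illustrative field names:

    from google.cloud.firestore_v1.field_path import FieldPath

    path = FieldPath.from_string("address.city")
    print(path.parts)           # -> ('address', 'city')
    print(path.to_api_repr())   # -> "address.city"

    # Parts that are not simple identifiers are quoted with backticks.
    odd = FieldPath("with.dot")
    print(odd.to_api_repr())    # -> "`with.dot`"

    # lineage() returns the set of parent paths, matching the new set["FieldPath"] hint.
    print(FieldPath("a", "b", "c").lineage())   # -> parents {FieldPath('a'), FieldPath('a', 'b')}
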
google/cloud/firestore_v1/watch.py

Lines changed: 21 additions & 16 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 import collections
 import functools
@@ -232,7 +233,7 @@ def __init__(
     def _init_stream(self):
         rpc_request = self._get_rpc_request
 
-        self._rpc = ResumableBidiRpc(
+        self._rpc: ResumableBidiRpc | None = ResumableBidiRpc(
             start_rpc=self._api._transport.listen,
             should_recover=_should_recover,
             should_terminate=_should_terminate,
@@ -243,7 +244,9 @@ def _init_stream(self):
         self._rpc.add_done_callback(self._on_rpc_done)
 
         # The server assigns and updates the resume token.
-        self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot)
+        self._consumer: BackgroundConsumer | None = BackgroundConsumer(
+            self._rpc, self.on_snapshot
+        )
         self._consumer.start()
 
     @classmethod
@@ -330,16 +333,18 @@ def close(self, reason=None):
             return
 
         # Stop consuming messages.
-        if self.is_active:
-            _LOGGER.debug("Stopping consumer.")
-            self._consumer.stop()
-            self._consumer._on_response = None
+        if self._consumer:
+            if self.is_active:
+                _LOGGER.debug("Stopping consumer.")
+                self._consumer.stop()
+            self._consumer._on_response = None
         self._consumer = None
 
         self._snapshot_callback = None
-        self._rpc.close()
-        self._rpc._initial_request = None
-        self._rpc._callbacks = []
+        if self._rpc:
+            self._rpc.close()
+            self._rpc._initial_request = None
+            self._rpc._callbacks = []
         self._rpc = None
         self._closed = True
         _LOGGER.debug("Finished stopping manager.")
@@ -460,13 +465,13 @@ def on_snapshot(self, proto):
                 message = f"Unknown target change type: {target_change_type}"
                 _LOGGER.info(f"on_snapshot: {message}")
                 self.close(reason=ValueError(message))
-
-            try:
-                # Use 'proto' vs 'pb' for datetime handling
-                meth(self, proto.target_change)
-            except Exception as exc2:
-                _LOGGER.debug(f"meth(proto) exc: {exc2}")
-                raise
+            else:
+                try:
+                    # Use 'proto' vs 'pb' for datetime handling
+                    meth(self, proto.target_change)
+                except Exception as exc2:
+                    _LOGGER.debug(f"meth(proto) exc: {exc2}")
+                    raise
 
         # NOTE:
         # in other implementations, such as node, the backoff is reset here

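The reworked close() only touches _consumer and _rpc when they are still set, and the snapshot dispatch only runs when a handler method was found, so a repeated close() or an unknown target-change type no longer trips over None. A stripped-down sketch of the same guard pattern; the names here are stand-ins, not the library's classes:

    class FakeStream:
        """Illustrative only: mirrors the None-guarded close() added above."""

        def __init__(self, consumer, rpc):
            self._consumer = consumer
            self._rpc = rpc
            self._closed = False

        def close(self):
            if self._closed:
                return
            if self._consumer:        # may already be gone on a re-entrant close
                self._consumer.stop()
                self._consumer = None
            if self._rpc:             # likewise guarded before use
                self._rpc.close()
                self._rpc = None
            self._closed = True


    class _Stub:
        def stop(self):
            pass

        def close(self):
            pass


    s = FakeStream(_Stub(), _Stub())
    s.close()
    s.close()   # second call is a no-op instead of an AttributeError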