Skip to content
Merged
41 changes: 28 additions & 13 deletions google/cloud/firestore_v1/async_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
"""Helpers for applying Google Cloud Firestore changes in a transaction."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
Optional,
)

from google.api_core import exceptions, gapic_v1
from google.api_core import retry_async as retries
Expand All @@ -37,11 +44,15 @@
# Types needed only for Type Hints
if TYPE_CHECKING: # pragma: NO COVER
import datetime
from typing_extensions import TypeVar, ParamSpec, Concatenate

from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_profile import ExplainOptions

T = TypeVar("T")
P = ParamSpec("P")


class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
"""Accumulate read-and-write operations to be sent in a transaction.
Expand Down Expand Up @@ -253,12 +264,14 @@ class _AsyncTransactional(_BaseTransactional):
A coroutine that should be run (and retried) in a transaction.
"""

def __init__(self, to_wrap) -> None:
def __init__(
self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
) -> None:
super(_AsyncTransactional, self).__init__(to_wrap)

async def _pre_commit(
self, transaction: AsyncTransaction, *args, **kwargs
) -> Coroutine:
self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
) -> T:
"""Begin transaction and call the wrapped coroutine.

Args:
Expand All @@ -271,7 +284,7 @@ async def _pre_commit(
along to the wrapped coroutine.

Returns:
Any: result of the wrapped coroutine.
T: result of the wrapped coroutine.

Raises:
Exception: Any failure caused by ``to_wrap``.
Expand All @@ -286,20 +299,22 @@ async def _pre_commit(
self.retry_id = self.current_id
return await self.to_wrap(transaction, *args, **kwargs)

async def __call__(self, transaction, *args, **kwargs):
async def __call__(
self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
) -> T:
"""Execute the wrapped callable within a transaction.

Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
(:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
A transaction to execute the callable within.
args (Tuple[Any, ...]): The extra positional arguments to pass
along to the wrapped callable.
kwargs (Dict[str, Any]): The extra keyword arguments to pass
along to the wrapped callable.

Returns:
Any: The result of the wrapped callable.
T: The result of the wrapped callable.

Raises:
ValueError: If the transaction does not succeed in
Expand All @@ -313,7 +328,7 @@ async def __call__(self, transaction, *args, **kwargs):

try:
for attempt in range(transaction._max_attempts):
result = await self._pre_commit(transaction, *args, **kwargs)
result: T = await self._pre_commit(transaction, *args, **kwargs)
try:
await transaction._commit()
return result
Expand All @@ -338,17 +353,17 @@ async def __call__(self, transaction, *args, **kwargs):


def async_transactional(
to_wrap: Callable[[AsyncTransaction], Any]
) -> _AsyncTransactional:
to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]:
"""Decorate a callable so that it runs in a transaction.

Args:
to_wrap
(Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]):
(Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]):
A callable that should be run (and retried) in a transaction.

Returns:
Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]:
Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Awaitable[Any]]:
the wrapped callable.
"""
return _AsyncTransactional(to_wrap)
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[mypy]
python_version = 3.6
python_version = 3.8
namespace_packages = True
Loading