Skip to content

Commit a6f20b9

Browse files
authored
LRO continuation_token (#10801)
* LRO continuation_token * from cont token is a clsmethod * Add async ABC for cont token * Pickle and azure-core pipeline * Make a aiohttp response pickable, but loosing the internal response * Add AsyncLROPoller * mypy * mpylint * Continuation token are optional abstract methods * Add async status * mypy * base64 the continuation token to be a string and not bytes * Typo * Tests and new AsyncPoller * Fix mypy * Fix tests for Python 2.7 * More tests * Add more tests, including asyncio_ensure_future wrapper
1 parent f88d011 commit a6f20b9

File tree

10 files changed

+558
-35
lines changed

10 files changed

+558
-35
lines changed

sdk/core/azure-core/azure/core/pipeline/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,21 @@ def __init__(self, transport, **kwargs): # pylint: disable=super-init-not-calle
6969
self.options = kwargs
7070
self._protected = ["transport", "options"]
7171

72+
def __getstate__(self):
73+
state = self.__dict__.copy()
74+
# Remove the unpicklable entries.
75+
del state['transport']
76+
return state
77+
78+
def __setstate__(self, state):
79+
self.__dict__.update(state)
80+
# Re-create the unpickable entries
81+
self.transport = None
82+
7283
def __setitem__(self, key, item):
73-
if key in self._protected:
84+
# If reloaded from pickle, _protected might not be here until restored by pickle
85+
# this explains the hasattr test
86+
if hasattr(self, '_protected') and key in self._protected:
7487
raise ValueError("Context value {} cannot be overwritten.".format(key))
7588
return super(PipelineContext, self).__setitem__(key, item)
7689

sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,14 @@ def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
299299
:type pipeline: azure.core.pipeline
300300
"""
301301
return AioHttpStreamDownloadGenerator(pipeline, self)
302+
303+
def __getstate__(self):
304+
# Be sure body is loaded in memory, otherwise not pickable and let it throw
305+
self.body()
306+
307+
state = self.__dict__.copy()
308+
# Remove the unpicklable entries.
309+
state['internal_response'] = None # aiohttp response are not pickable (see headers comments)
310+
from multidict import MultiDict # I know it's importable since aiohttp is loaded
311+
state['headers'] = MultiDict(self.headers) # MultiDictProxy is not pickable
312+
return state

sdk/core/azure-core/azure/core/polling/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@
3131
#pylint: disable=unused-import
3232
if sys.version_info >= (3, 5, 2):
3333
# Not executed on old Python, no syntax error
34-
from ._async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller
35-
__all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller']
34+
from ._async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller, AsyncLROPoller
35+
__all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller', 'AsyncLROPoller']

sdk/core/azure-core/azure/core/polling/_async_poller.py

Lines changed: 115 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
# IN THE SOFTWARE.
2424
#
2525
# --------------------------------------------------------------------------
26-
from typing import Generic, TypeVar, Any
26+
from collections.abc import Awaitable
27+
from typing import Callable, Any, Tuple, Generic, TypeVar, Generator
28+
2729
from ._poller import NoPolling as _NoPolling
2830

2931

@@ -48,6 +50,21 @@ def finished(self) -> bool:
4850
def resource(self) -> PollingReturnType:
4951
raise NotImplementedError("This method needs to be implemented")
5052

53+
def get_continuation_token(self) -> str:
54+
raise TypeError(
55+
"Polling method '{}' doesn't support get_continuation_token".format(
56+
self.__class__.__name__
57+
)
58+
)
59+
60+
@classmethod
61+
def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]:
62+
raise TypeError(
63+
"Polling method '{}' doesn't support from_continuation_token".format(
64+
cls.__name__
65+
)
66+
)
67+
5168

5269
class AsyncNoPolling(_NoPolling):
5370
"""An empty async poller that returns the deserialized initial response.
@@ -61,6 +78,9 @@ async def run(self):
6178
async def async_poller(client, initial_response, deserialization_callback, polling_method):
6279
"""Async Poller for long running operations.
6380
81+
.. deprecated:: 1.5.0
82+
Use :class:`AsyncLROPoller` instead.
83+
6484
:param client: A pipeline service client.
6585
:type client: ~azure.core.PipelineClient
6686
:param initial_response: The initial call response
@@ -71,15 +91,100 @@ async def async_poller(client, initial_response, deserialization_callback, polli
7191
:param polling_method: The polling strategy to adopt
7292
:type polling_method: ~azure.core.polling.PollingMethod
7393
"""
94+
poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method)
95+
return await poller
96+
97+
98+
class AsyncLROPoller(Awaitable, Generic[PollingReturnType]):
99+
"""Async poller for long running operations.
100+
101+
:param client: A pipeline service client
102+
:type client: ~azure.core.PipelineClient
103+
:param initial_response: The initial call response
104+
:type initial_response:
105+
~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse
106+
:param deserialization_callback: A callback that takes a Response and return a deserialized object.
107+
If a subclass of Model is given, this passes "deserialize" as callback.
108+
:type deserialization_callback: callable or msrest.serialization.Model
109+
:param polling_method: The polling strategy to adopt
110+
:type polling_method: ~azure.core.polling.AsyncPollingMethod
111+
"""
112+
113+
def __init__(
114+
self,
115+
client: Any,
116+
initial_response: Any,
117+
deserialization_callback: Callable,
118+
polling_method: AsyncPollingMethod[PollingReturnType]
119+
):
120+
self._polling_method = polling_method
121+
self._done = False
122+
123+
# This implicit test avoids bringing in an explicit dependency on Model directly
124+
try:
125+
deserialization_callback = deserialization_callback.deserialize # type: ignore
126+
except AttributeError:
127+
pass
128+
129+
self._polling_method.initialize(client, initial_response, deserialization_callback)
130+
131+
def polling_method(self) -> AsyncPollingMethod[PollingReturnType]:
132+
"""Return the polling method associated to this poller.
133+
"""
134+
return self._polling_method
135+
136+
def continuation_token(self) -> str:
137+
"""Return a continuation token that allows to restart the poller later.
138+
139+
:returns: An opaque continuation token
140+
:rtype: str
141+
"""
142+
return self._polling_method.get_continuation_token()
143+
144+
@classmethod
145+
def from_continuation_token(
146+
cls,
147+
polling_method: AsyncPollingMethod[PollingReturnType],
148+
continuation_token: str,
149+
**kwargs
150+
) -> "AsyncLROPoller[PollingReturnType]":
151+
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
152+
continuation_token, **kwargs
153+
)
154+
return cls(client, initial_response, deserialization_callback, polling_method)
74155

75-
# This implicit test avoids bringing in an explicit dependency on Model directly
76-
try:
77-
deserialization_callback = deserialization_callback.deserialize
78-
except AttributeError:
79-
pass
156+
def status(self) -> str:
157+
"""Returns the current status string.
158+
159+
:returns: The current status string
160+
:rtype: str
161+
"""
162+
return self._polling_method.status()
80163

81-
# Might raise a CloudError
82-
polling_method.initialize(client, initial_response, deserialization_callback)
164+
async def result(self) -> PollingReturnType:
165+
"""Return the result of the long running operation.
83166
84-
await polling_method.run()
85-
return polling_method.resource()
167+
:returns: The deserialized resource of the long running operation, if one is available.
168+
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
169+
"""
170+
await self.wait()
171+
return self._polling_method.resource()
172+
173+
def __await__(self) -> Generator[Any, None, PollingReturnType]:
174+
return self.result().__await__()
175+
176+
async def wait(self) -> None:
177+
"""Wait on the long running operation.
178+
179+
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
180+
"""
181+
await self._polling_method.run()
182+
self._done = True
183+
184+
def done(self) -> bool:
185+
"""Check status of the long running operation.
186+
187+
:returns: 'True' if the process has completed, else 'False'.
188+
:rtype: bool
189+
"""
190+
return self._done

sdk/core/azure-core/azure/core/polling/_poller.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,22 @@
2323
# IN THE SOFTWARE.
2424
#
2525
# --------------------------------------------------------------------------
26+
import base64
2627
import threading
2728
import uuid
2829
try:
2930
from urlparse import urlparse # type: ignore # pylint: disable=unused-import
3031
except ImportError:
3132
from urllib.parse import urlparse
3233

33-
from typing import Any, Callable, Union, List, Optional, TypeVar, Generic, TYPE_CHECKING
34+
from typing import Any, Callable, Union, List, Optional, Tuple, TypeVar, Generic
3435
from azure.core.pipeline.transport._base import HttpResponse
3536
from azure.core.tracing.decorator import distributed_trace
3637
from azure.core.tracing.common import with_current_context
3738

38-
if TYPE_CHECKING:
39-
import requests
40-
from msrest.serialization import Model # pylint: disable=unused-import
41-
DeserializationCallbackType = Union[Model, Callable[[requests.Response], Model]]
4239
PollingReturnType = TypeVar("PollingReturnType")
4340

41+
4442
class PollingMethod(Generic[PollingReturnType]):
4543
"""ABC class for polling method.
4644
"""
@@ -64,6 +62,24 @@ def resource(self):
6462
# type: () -> PollingReturnType
6563
raise NotImplementedError("This method needs to be implemented")
6664

65+
def get_continuation_token(self):
66+
# type() -> str
67+
raise TypeError(
68+
"Polling method '{}' doesn't support get_continuation_token".format(
69+
self.__class__.__name__
70+
)
71+
)
72+
73+
@classmethod
74+
def from_continuation_token(cls, continuation_token, **kwargs):
75+
# type(str, Any) -> Tuple[Any, Any, Callable]
76+
raise TypeError(
77+
"Polling method '{}' doesn't support from_continuation_token".format(
78+
cls.__name__
79+
)
80+
)
81+
82+
6783
class NoPolling(PollingMethod):
6884
"""An empty poller that returns the deserialized initial response.
6985
"""
@@ -72,7 +88,7 @@ def __init__(self):
7288
self._deserialization_callback = None
7389

7490
def initialize(self, _, initial_response, deserialization_callback):
75-
# type: (Any, requests.Response, Callable) -> None
91+
# type: (Any, Any, Callable) -> None
7692
self._initial_response = initial_response
7793
self._deserialization_callback = deserialization_callback
7894

@@ -101,6 +117,22 @@ def resource(self):
101117
# type: () -> Any
102118
return self._deserialization_callback(self._initial_response)
103119

120+
def get_continuation_token(self):
121+
# type() -> str
122+
import pickle
123+
return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii')
124+
125+
@classmethod
126+
def from_continuation_token(cls, continuation_token, **kwargs):
127+
# type(str, Any) -> Tuple
128+
try:
129+
deserialization_callback = kwargs["deserialization_callback"]
130+
except KeyError:
131+
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token")
132+
import pickle
133+
initial_response = pickle.loads(base64.b64decode(continuation_token))
134+
return None, initial_response, deserialization_callback
135+
104136

105137
class LROPoller(Generic[PollingReturnType]):
106138
"""Poller for long running operations.
@@ -118,9 +150,7 @@ class LROPoller(Generic[PollingReturnType]):
118150
"""
119151

120152
def __init__(self, client, initial_response, deserialization_callback, polling_method):
121-
# type: (Any, HttpResponse, DeserializationCallbackType, PollingMethod) -> None
122-
self._client = client
123-
self._response = initial_response
153+
# type: (Any, HttpResponse, Callable, PollingMethod[PollingReturnType]) -> None
124154
self._callbacks = [] # type: List[Callable]
125155
self._polling_method = polling_method
126156

@@ -131,7 +161,7 @@ def __init__(self, client, initial_response, deserialization_callback, polling_m
131161
pass
132162

133163
# Might raise a CloudError
134-
self._polling_method.initialize(self._client, self._response, deserialization_callback)
164+
self._polling_method.initialize(client, initial_response, deserialization_callback)
135165

136166
# Prepare thread execution
137167
self._thread = None
@@ -166,6 +196,29 @@ def _start(self):
166196
call(self._polling_method)
167197
callbacks, self._callbacks = self._callbacks, []
168198

199+
def polling_method(self):
200+
# type: () -> PollingMethod[PollingReturnType]
201+
"""Return the polling method associated to this poller.
202+
"""
203+
return self._polling_method
204+
205+
def continuation_token(self):
206+
# type: () -> str
207+
"""Return a continuation token that allows to restart the poller later.
208+
209+
:returns: An opaque continuation token
210+
:rtype: str
211+
"""
212+
return self._polling_method.get_continuation_token()
213+
214+
@classmethod
215+
def from_continuation_token(cls, polling_method, continuation_token, **kwargs):
216+
# type: (PollingMethod[PollingReturnType], str, Any) -> LROPoller[PollingReturnType]
217+
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
218+
continuation_token, **kwargs
219+
)
220+
return cls(client, initial_response, deserialization_callback, polling_method)
221+
169222
def status(self):
170223
# type: () -> str
171224
"""Returns the current status string.

sdk/core/azure-core/azure/core/polling/base_polling.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#
2525
# --------------------------------------------------------------------------
2626
import abc
27+
import base64
2728
import json
2829
from typing import TYPE_CHECKING, Optional, Any, Union
2930

@@ -447,6 +448,30 @@ def initialize(self, client, initial_response, deserialization_callback):
447448
except OperationFailed as err:
448449
raise HttpResponseError(response=initial_response.http_response, error=err)
449450

451+
def get_continuation_token(self):
452+
# type() -> str
453+
import pickle
454+
return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii')
455+
456+
@classmethod
457+
def from_continuation_token(cls, continuation_token, **kwargs):
458+
# type(str, Any) -> Tuple
459+
try:
460+
client = kwargs["client"]
461+
except KeyError:
462+
raise ValueError("Need kwarg 'client' to be recreated from continuation_token")
463+
464+
try:
465+
deserialization_callback = kwargs["deserialization_callback"]
466+
except KeyError:
467+
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token")
468+
469+
import pickle
470+
initial_response = pickle.loads(base64.b64decode(continuation_token))
471+
# Restore the transport in the context
472+
initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access
473+
return client, initial_response, deserialization_callback
474+
450475
def run(self):
451476
try:
452477
self._poll()

0 commit comments

Comments
 (0)