Skip to content

Commit 078f660

Browse files
committed
add AsyncAPIClient
1 parent 522700a commit 078f660

File tree

11 files changed

+309
-67
lines changed

11 files changed

+309
-67
lines changed

apiclient/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
NoAuthentication,
66
QueryParameterAuthentication,
77
)
8-
from apiclient.client import APIClient
8+
from apiclient.client import AbstractClient, APIClient, AsyncAPIClient
99
from apiclient.decorates import endpoint
1010
from apiclient.paginators import paginated
1111
from apiclient.request_formatters import JsonRequestFormatter

apiclient/authentication_methods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if TYPE_CHECKING: # pragma: no cover
77
# Stupid way of getting around cyclic imports when
88
# using typehinting.
9-
from apiclient import APIClient
9+
from apiclient.client import AbstractClient
1010

1111

1212
class BaseAuthenticationMethod:
@@ -19,7 +19,7 @@ def get_query_params(self) -> dict:
1919
def get_username_password_authentication(self) -> Optional[BasicAuthType]:
2020
return None
2121

22-
def perform_initial_auth(self, client: "APIClient"):
22+
def perform_initial_auth(self, client: "AbstractClient"):
2323
pass
2424

2525

@@ -91,7 +91,7 @@ def __init__(
9191
self._auth_url = auth_url
9292
self._authentication = authentication
9393

94-
def perform_initial_auth(self, client: "APIClient"):
94+
def perform_initial_auth(self, client: "AbstractClient"):
9595
client.get(
9696
self._auth_url,
9797
headers=self._authentication.get_headers(),

apiclient/client.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from apiclient.authentication_methods import BaseAuthenticationMethod, NoAuthentication
66
from apiclient.error_handlers import BaseErrorHandler, ErrorHandler
77
from apiclient.request_formatters import BaseRequestFormatter, NoOpRequestFormatter
8-
from apiclient.request_strategies import BaseRequestStrategy, RequestStrategy
8+
from apiclient.request_strategies import AsyncRequestStrategy, BaseRequestStrategy, RequestStrategy
99
from apiclient.response_handlers import BaseResponseHandler, RequestsResponseHandler
1010
from apiclient.utils.typing import OptionalDict
1111

@@ -15,7 +15,7 @@
1515
DEFAULT_TIMEOUT = 10.0
1616

1717

18-
class APIClient:
18+
class AbstractClient:
1919
def __init__(
2020
self,
2121
authentication_method: Optional[BaseAuthenticationMethod] = None,
@@ -37,11 +37,14 @@ def __init__(
3737
self.set_response_handler(response_handler)
3838
self.set_error_handler(error_handler)
3939
self.set_request_formatter(request_formatter)
40-
self.set_request_strategy(request_strategy or RequestStrategy())
40+
self.set_request_strategy(request_strategy or self.get_default_request_strategy())
4141

4242
# Perform any one time authentication required by api
4343
self._authentication_method.perform_initial_auth(self)
4444

45+
def get_default_request_strategy(self): # pragma: no cover
46+
raise NotImplementedError
47+
4548
def get_session(self) -> Any:
4649
return self._session
4750

@@ -135,3 +138,44 @@ def delete(self, endpoint: str, params: OptionalDict = None, **kwargs):
135138
"""Remove resource with DELETE endpoint."""
136139
LOG.debug("DELETE %s", endpoint)
137140
return self.get_request_strategy().delete(endpoint, params=params, **kwargs)
141+
142+
143+
class APIClient(AbstractClient):
144+
def get_default_request_strategy(self):
145+
return RequestStrategy()
146+
147+
148+
class AsyncAPIClient(AbstractClient):
149+
async def __aenter__(self):
150+
session = await self._request_strategy.create_session()
151+
self.set_session(session)
152+
return self
153+
154+
async def __aexit__(self, exc_type, exc_value, traceback):
155+
session = self.get_session()
156+
if session:
157+
await session.close()
158+
self.set_session(None)
159+
160+
def get_default_request_strategy(self):
161+
return AsyncRequestStrategy()
162+
163+
async def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
164+
"""Send data and return response data from POST endpoint."""
165+
return await self.get_request_strategy().post(endpoint, data=data, params=params, **kwargs)
166+
167+
async def get(self, endpoint: str, params: OptionalDict = None, **kwargs):
168+
"""Return response data from GET endpoint."""
169+
return await self.get_request_strategy().get(endpoint, params=params, **kwargs)
170+
171+
async def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
172+
"""Send data to overwrite resource and return response data from PUT endpoint."""
173+
return await self.get_request_strategy().put(endpoint, data=data, params=params, **kwargs)
174+
175+
async def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
176+
"""Send data to update resource and return response data from PATCH endpoint."""
177+
return await self.get_request_strategy().patch(endpoint, data=data, params=params, **kwargs)
178+
179+
async def delete(self, endpoint: str, params: OptionalDict = None, **kwargs):
180+
"""Remove resource with DELETE endpoint."""
181+
return await self.get_request_strategy().delete(endpoint, params=params, **kwargs)

apiclient/request_strategies.py

Lines changed: 115 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,99 @@
11
from copy import deepcopy
2-
from typing import TYPE_CHECKING, Callable
2+
from typing import TYPE_CHECKING, Any, Callable
33

4+
import aiohttp
45
import requests
56

67
from apiclient.exceptions import UnexpectedError
7-
from apiclient.response import RequestsResponse, Response
8+
from apiclient.response import AioHttpResponse, RequestsResponse, Response
89
from apiclient.utils.typing import OptionalDict
910

1011
if TYPE_CHECKING: # pragma: no cover
1112
# Stupid way of getting around cyclic imports when
1213
# using typehinting.
13-
from apiclient import APIClient
14+
from apiclient.client import AbstractClient
1415

1516

1617
class BaseRequestStrategy:
17-
def set_client(self, client: "APIClient"):
18+
def set_client(self, client: "AbstractClient"):
1819
self._client = client
1920

20-
def get_client(self) -> "APIClient":
21+
def get_session(self):
22+
return self.get_client().get_session()
23+
24+
def set_session(self, session: Any):
25+
self.get_client().set_session(session)
26+
27+
def create_session(self): # pragma: no cover
28+
"""Abstract method that will create a session object."""
29+
raise NotImplementedError
30+
31+
def get_client(self) -> "AbstractClient":
2132
return self._client
2233

23-
def post(self, *args, **kwargs): # pragma: no cover
34+
def _get_request_params(self, params: OptionalDict) -> dict:
35+
"""Return dictionary with any additional authentication query parameters."""
36+
if params is None:
37+
params = {}
38+
params.update(self.get_client().get_default_query_params())
39+
return params
40+
41+
def _get_request_headers(self, headers: OptionalDict) -> dict:
42+
"""Return dictionary with any additional authentication headers."""
43+
if headers is None:
44+
headers = {}
45+
headers.update(self.get_client().get_default_headers())
46+
return headers
47+
48+
def _get_username_password_authentication(self):
49+
return self.get_client().get_default_username_password_authentication()
50+
51+
def _get_formatted_data(self, data: OptionalDict):
52+
return self.get_client().get_request_formatter().format(data)
53+
54+
def _get_request_timeout(self) -> float:
55+
"""Return the number of seconds before the request times out."""
56+
return self.get_client().get_request_timeout()
57+
58+
def _check_response(self, response: Response):
59+
"""Raise a custom exception if the response is not OK."""
60+
status_code = response.get_status_code()
61+
if status_code < 200 or status_code >= 300:
62+
self._handle_bad_response(response)
63+
64+
def _decode_response_data(self, response: Response):
65+
return self.get_client().get_response_handler().get_request_data(response)
66+
67+
def _handle_bad_response(self, response: Response):
68+
"""Convert the error into an understandable client exception."""
69+
raise self.get_client().get_error_handler().get_exception(response)
70+
71+
def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover
2472
raise NotImplementedError
2573

26-
def get(self, *args, **kwargs): # pragma: no cover
74+
def get(self, endpoint: str, params: OptionalDict = None, **kwargs): # pragma: no cover
2775
raise NotImplementedError
2876

29-
def put(self, *args, **kwargs): # pragma: no cover
77+
def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover
3078
raise NotImplementedError
3179

32-
def patch(self, *args, **kwargs): # pragma: no cover
80+
def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover
3381
raise NotImplementedError
3482

35-
def delete(self, *args, **kwargs): # pragma: no cover
83+
def delete(self, endpoint: str, params: OptionalDict = None, **kwargs): # pragma: no cover
3684
raise NotImplementedError
3785

3886

3987
class RequestStrategy(BaseRequestStrategy):
4088
"""Requests strategy that uses the `requests` lib with a `requests.session`."""
4189

42-
def set_client(self, client: "APIClient"):
90+
def set_client(self, client: "AbstractClient"):
4391
super().set_client(client)
44-
# Set a global `requests.session` on the parent client instance.
4592
if self.get_session() is None:
46-
self.set_session(requests.session())
93+
self.set_session(self.create_session())
4794

48-
def get_session(self):
49-
return self.get_client().get_session()
50-
51-
def set_session(self, session: requests.Session):
52-
self.get_client().set_session(session)
95+
def create_session(self) -> requests.Session:
96+
return requests.session()
5397

5498
def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
5599
"""Send data and return response data from POST endpoint."""
@@ -102,43 +146,6 @@ def _make_request(
102146
self._check_response(response)
103147
return self._decode_response_data(response)
104148

105-
def _get_request_params(self, params: OptionalDict) -> dict:
106-
"""Return dictionary with any additional authentication query parameters."""
107-
if params is None:
108-
params = {}
109-
params.update(self.get_client().get_default_query_params())
110-
return params
111-
112-
def _get_request_headers(self, headers: OptionalDict) -> dict:
113-
"""Return dictionary with any additional authentication headers."""
114-
if headers is None:
115-
headers = {}
116-
headers.update(self.get_client().get_default_headers())
117-
return headers
118-
119-
def _get_username_password_authentication(self):
120-
return self.get_client().get_default_username_password_authentication()
121-
122-
def _get_formatted_data(self, data: OptionalDict):
123-
return self.get_client().get_request_formatter().format(data)
124-
125-
def _get_request_timeout(self) -> float:
126-
"""Return the number of seconds before the request times out."""
127-
return self.get_client().get_request_timeout()
128-
129-
def _check_response(self, response: Response):
130-
"""Raise a custom exception if the response is not OK."""
131-
status_code = response.get_status_code()
132-
if status_code < 200 or status_code >= 300:
133-
self._handle_bad_response(response)
134-
135-
def _decode_response_data(self, response: Response):
136-
return self.get_client().get_response_handler().get_request_data(response)
137-
138-
def _handle_bad_response(self, response: Response):
139-
"""Convert the error into an understandable client exception."""
140-
raise self.get_client().get_error_handler().get_exception(response)
141-
142149

143150
class QueryParamPaginatedRequestStrategy(RequestStrategy):
144151
"""Strategy for GET requests where pages are defined in query params."""
@@ -192,3 +199,56 @@ def get(self, endpoint: str, params: OptionalDict = None, **kwargs):
192199

193200
def get_next_page_url(self, response, previous_page_url: str) -> OptionalDict:
194201
return self._next_page(response, previous_page_url)
202+
203+
204+
class AsyncRequestStrategy(BaseRequestStrategy):
205+
async def create_session(self) -> aiohttp.ClientSession:
206+
return aiohttp.ClientSession()
207+
208+
def get_session(self) -> aiohttp.ClientSession:
209+
return self.get_client().get_session()
210+
211+
async def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
212+
return await self._make_request(
213+
self.get_session().post, endpoint, data=data, params=params, **kwargs
214+
)
215+
216+
async def get(self, endpoint: str, params: OptionalDict = None, **kwargs):
217+
return await self._make_request(self.get_session().get, endpoint, params=params, **kwargs)
218+
219+
async def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
220+
return await self._make_request(self.get_session().put, endpoint, data=data, params=params, **kwargs)
221+
222+
async def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
223+
return await self._make_request(
224+
self.get_session().patch, endpoint, data=data, params=params, **kwargs
225+
)
226+
227+
async def delete(self, endpoint: str, params: OptionalDict = None, **kwargs):
228+
return await self._make_request(self.get_session().delete, endpoint, params=params, **kwargs)
229+
230+
async def _make_request(
231+
self,
232+
request_method: Callable,
233+
endpoint: str,
234+
params: OptionalDict = None,
235+
headers: OptionalDict = None,
236+
data: OptionalDict = None,
237+
**kwargs,
238+
) -> Response:
239+
try:
240+
async with request_method(
241+
endpoint,
242+
params=self._get_request_params(params),
243+
headers=self._get_request_headers(headers),
244+
auth=self._get_username_password_authentication(),
245+
data=self._get_formatted_data(data),
246+
timeout=self._get_request_timeout(),
247+
**kwargs,
248+
) as raw_response:
249+
response = AioHttpResponse(raw_response, content=await raw_response.read())
250+
except Exception as error:
251+
raise UnexpectedError(f"Error when contacting '{endpoint}'") from error
252+
else:
253+
self._check_response(response)
254+
return self._decode_response_data(response)

apiclient/response.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import json
12
from typing import Any
23

4+
import aiohttp
35
import requests
46

57
from apiclient.utils.typing import JsonType
@@ -62,3 +64,26 @@ def get_status_reason(self) -> str:
6264

6365
def get_requested_url(self) -> str:
6466
return self._response.url
67+
68+
69+
class AioHttpResponse(RequestsResponse):
70+
"""Implementation of the response for a requests.response type."""
71+
72+
def __init__(self, response: aiohttp.ClientResponse, content: bytes):
73+
self._response = response
74+
self._content = content
75+
self._text = ""
76+
77+
def get_status_code(self) -> int:
78+
return self._response.status
79+
80+
def get_raw_data(self) -> str:
81+
if not self._text:
82+
self._text = self._content.decode(self._response.get_encoding(), errors="strict")
83+
return self._text
84+
85+
def get_json(self) -> JsonType:
86+
return json.loads(self._text)
87+
88+
def get_requested_url(self) -> str:
89+
return str(self._response.url)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[tool:pytest]
2-
addopts = --cov=apiclient/ --cov-fail-under=100 --cov-report html
2+
addopts = --asyncio-mode=auto --cov=apiclient/ --cov-fail-under=100 --cov-report html
33
env =
44
ENDPOINT_BASE_URL=http://environment.com
55

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import setuptools
44

55
# Pinning tenacity as the api has changed slightly which breaks all tests.
6-
application_dependencies = ["requests>=2.16", "tenacity>=5.1.0"]
6+
application_dependencies = ["requests>=2.16", "aiohttp>=3.8", "tenacity>=5.1.0"]
77
prod_dependencies = []
8-
test_dependencies = ["pytest", "pytest-env", "pytest-cov", "vcrpy", "requests-mock"]
8+
test_dependencies = ["pytest", "pytest-env", "pytest-cov", "vcrpy", "requests-mock", "pytest-asyncio", "aioresponses"]
99
lint_dependencies = ["flake8", "flake8-docstrings", "black", "isort"]
1010
docs_dependencies = []
1111
dev_dependencies = test_dependencies + lint_dependencies + docs_dependencies + ["ipdb"]

0 commit comments

Comments
 (0)