Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apiclient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
NoAuthentication,
QueryParameterAuthentication,
)
from apiclient.client import APIClient
from apiclient.client import AbstractClient, APIClient, AsyncAPIClient
from apiclient.decorates import endpoint
from apiclient.paginators import paginated
from apiclient.request_formatters import JsonRequestFormatter
Expand Down
6 changes: 3 additions & 3 deletions apiclient/authentication_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING: # pragma: no cover
# Stupid way of getting around cyclic imports when
# using typehinting.
from apiclient import APIClient
from apiclient.client import AbstractClient


class BaseAuthenticationMethod:
Expand All @@ -19,7 +19,7 @@ def get_query_params(self) -> dict:
def get_username_password_authentication(self) -> Optional[BasicAuthType]:
return None

def perform_initial_auth(self, client: "APIClient"):
def perform_initial_auth(self, client: "AbstractClient"):
pass


Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
self._auth_url = auth_url
self._authentication = authentication

def perform_initial_auth(self, client: "APIClient"):
def perform_initial_auth(self, client: "AbstractClient"):
client.get(
self._auth_url,
headers=self._authentication.get_headers(),
Expand Down
50 changes: 47 additions & 3 deletions apiclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from apiclient.authentication_methods import BaseAuthenticationMethod, NoAuthentication
from apiclient.error_handlers import BaseErrorHandler, ErrorHandler
from apiclient.request_formatters import BaseRequestFormatter, NoOpRequestFormatter
from apiclient.request_strategies import BaseRequestStrategy, RequestStrategy
from apiclient.request_strategies import AsyncRequestStrategy, BaseRequestStrategy, RequestStrategy
from apiclient.response_handlers import BaseResponseHandler, RequestsResponseHandler
from apiclient.utils.typing import OptionalDict

Expand All @@ -15,7 +15,7 @@
DEFAULT_TIMEOUT = 10.0


class APIClient:
class AbstractClient:
def __init__(
self,
authentication_method: Optional[BaseAuthenticationMethod] = None,
Expand All @@ -37,11 +37,14 @@ def __init__(
self.set_response_handler(response_handler)
self.set_error_handler(error_handler)
self.set_request_formatter(request_formatter)
self.set_request_strategy(request_strategy or RequestStrategy())
self.set_request_strategy(request_strategy or self.get_default_request_strategy())

# Perform any one time authentication required by api
self._authentication_method.perform_initial_auth(self)

def get_default_request_strategy(self): # pragma: no cover
raise NotImplementedError

def get_session(self) -> Any:
return self._session

Expand Down Expand Up @@ -135,3 +138,44 @@ def delete(self, endpoint: str, params: OptionalDict = None, **kwargs):
"""Remove resource with DELETE endpoint."""
LOG.debug("DELETE %s", endpoint)
return self.get_request_strategy().delete(endpoint, params=params, **kwargs)


class APIClient(AbstractClient):
def get_default_request_strategy(self):
return RequestStrategy()


class AsyncAPIClient(AbstractClient):
async def __aenter__(self):
session = await self._request_strategy.create_session()
self.set_session(session)
return self

async def __aexit__(self, exc_type, exc_value, traceback):
session = self.get_session()
if session:
await session.close()
self.set_session(None)

def get_default_request_strategy(self):
return AsyncRequestStrategy()

async def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
"""Send data and return response data from POST endpoint."""
return await self.get_request_strategy().post(endpoint, data=data, params=params, **kwargs)

async def get(self, endpoint: str, params: OptionalDict = None, **kwargs):
"""Return response data from GET endpoint."""
return await self.get_request_strategy().get(endpoint, params=params, **kwargs)

async def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
"""Send data to overwrite resource and return response data from PUT endpoint."""
return await self.get_request_strategy().put(endpoint, data=data, params=params, **kwargs)

async def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
"""Send data to update resource and return response data from PATCH endpoint."""
return await self.get_request_strategy().patch(endpoint, data=data, params=params, **kwargs)

async def delete(self, endpoint: str, params: OptionalDict = None, **kwargs):
"""Remove resource with DELETE endpoint."""
return await self.get_request_strategy().delete(endpoint, params=params, **kwargs)
170 changes: 115 additions & 55 deletions apiclient/request_strategies.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,99 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

import aiohttp
import requests

from apiclient.exceptions import UnexpectedError
from apiclient.response import RequestsResponse, Response
from apiclient.response import AioHttpResponse, RequestsResponse, Response
from apiclient.utils.typing import OptionalDict

if TYPE_CHECKING: # pragma: no cover
# Stupid way of getting around cyclic imports when
# using typehinting.
from apiclient import APIClient
from apiclient.client import AbstractClient


class BaseRequestStrategy:
def set_client(self, client: "APIClient"):
def set_client(self, client: "AbstractClient"):
self._client = client

def get_client(self) -> "APIClient":
def get_session(self):
return self.get_client().get_session()

def set_session(self, session: Any):
self.get_client().set_session(session)

def create_session(self): # pragma: no cover
"""Abstract method that will create a session object."""
raise NotImplementedError

def get_client(self) -> "AbstractClient":
return self._client

def post(self, *args, **kwargs): # pragma: no cover
def _get_request_params(self, params: OptionalDict) -> dict:
"""Return dictionary with any additional authentication query parameters."""
if params is None:
params = {}
params.update(self.get_client().get_default_query_params())
return params

def _get_request_headers(self, headers: OptionalDict) -> dict:
"""Return dictionary with any additional authentication headers."""
if headers is None:
headers = {}
headers.update(self.get_client().get_default_headers())
return headers

def _get_username_password_authentication(self):
return self.get_client().get_default_username_password_authentication()

def _get_formatted_data(self, data: OptionalDict):
return self.get_client().get_request_formatter().format(data)

def _get_request_timeout(self) -> float:
"""Return the number of seconds before the request times out."""
return self.get_client().get_request_timeout()

def _check_response(self, response: Response):
"""Raise a custom exception if the response is not OK."""
status_code = response.get_status_code()
if status_code < 200 or status_code >= 300:
self._handle_bad_response(response)

def _decode_response_data(self, response: Response):
return self.get_client().get_response_handler().get_request_data(response)

def _handle_bad_response(self, response: Response):
"""Convert the error into an understandable client exception."""
raise self.get_client().get_error_handler().get_exception(response)

def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover
raise NotImplementedError

def get(self, *args, **kwargs): # pragma: no cover
def get(self, endpoint: str, params: OptionalDict = None, **kwargs): # pragma: no cover
raise NotImplementedError

def put(self, *args, **kwargs): # pragma: no cover
def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover
raise NotImplementedError

def patch(self, *args, **kwargs): # pragma: no cover
def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover
raise NotImplementedError

def delete(self, *args, **kwargs): # pragma: no cover
def delete(self, endpoint: str, params: OptionalDict = None, **kwargs): # pragma: no cover
raise NotImplementedError


class RequestStrategy(BaseRequestStrategy):
"""Requests strategy that uses the `requests` lib with a `requests.session`."""

def set_client(self, client: "APIClient"):
def set_client(self, client: "AbstractClient"):
super().set_client(client)
# Set a global `requests.session` on the parent client instance.
if self.get_session() is None:
self.set_session(requests.session())
self.set_session(self.create_session())

def get_session(self):
return self.get_client().get_session()

def set_session(self, session: requests.Session):
self.get_client().set_session(session)
def create_session(self) -> requests.Session:
return requests.session()

def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
"""Send data and return response data from POST endpoint."""
Expand Down Expand Up @@ -102,43 +146,6 @@ def _make_request(
self._check_response(response)
return self._decode_response_data(response)

def _get_request_params(self, params: OptionalDict) -> dict:
"""Return dictionary with any additional authentication query parameters."""
if params is None:
params = {}
params.update(self.get_client().get_default_query_params())
return params

def _get_request_headers(self, headers: OptionalDict) -> dict:
"""Return dictionary with any additional authentication headers."""
if headers is None:
headers = {}
headers.update(self.get_client().get_default_headers())
return headers

def _get_username_password_authentication(self):
return self.get_client().get_default_username_password_authentication()

def _get_formatted_data(self, data: OptionalDict):
return self.get_client().get_request_formatter().format(data)

def _get_request_timeout(self) -> float:
"""Return the number of seconds before the request times out."""
return self.get_client().get_request_timeout()

def _check_response(self, response: Response):
"""Raise a custom exception if the response is not OK."""
status_code = response.get_status_code()
if status_code < 200 or status_code >= 300:
self._handle_bad_response(response)

def _decode_response_data(self, response: Response):
return self.get_client().get_response_handler().get_request_data(response)

def _handle_bad_response(self, response: Response):
"""Convert the error into an understandable client exception."""
raise self.get_client().get_error_handler().get_exception(response)


class QueryParamPaginatedRequestStrategy(RequestStrategy):
"""Strategy for GET requests where pages are defined in query params."""
Expand Down Expand Up @@ -192,3 +199,56 @@ def get(self, endpoint: str, params: OptionalDict = None, **kwargs):

def get_next_page_url(self, response, previous_page_url: str) -> OptionalDict:
return self._next_page(response, previous_page_url)


class AsyncRequestStrategy(BaseRequestStrategy):
async def create_session(self) -> aiohttp.ClientSession:
return aiohttp.ClientSession()

def get_session(self) -> aiohttp.ClientSession:
return self.get_client().get_session()

async def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
return await self._make_request(
self.get_session().post, endpoint, data=data, params=params, **kwargs
)

async def get(self, endpoint: str, params: OptionalDict = None, **kwargs):
return await self._make_request(self.get_session().get, endpoint, params=params, **kwargs)

async def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
return await self._make_request(self.get_session().put, endpoint, data=data, params=params, **kwargs)

async def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs):
return await self._make_request(
self.get_session().patch, endpoint, data=data, params=params, **kwargs
)

async def delete(self, endpoint: str, params: OptionalDict = None, **kwargs):
return await self._make_request(self.get_session().delete, endpoint, params=params, **kwargs)

async def _make_request(
self,
request_method: Callable,
endpoint: str,
params: OptionalDict = None,
headers: OptionalDict = None,
data: OptionalDict = None,
**kwargs,
) -> Response:
try:
async with request_method(
endpoint,
params=self._get_request_params(params),
headers=self._get_request_headers(headers),
auth=self._get_username_password_authentication(),
data=self._get_formatted_data(data),
timeout=self._get_request_timeout(),
**kwargs,
) as raw_response:
response = AioHttpResponse(raw_response, content=await raw_response.read())
except Exception as error:
raise UnexpectedError(f"Error when contacting '{endpoint}'") from error
else:
self._check_response(response)
return self._decode_response_data(response)
25 changes: 25 additions & 0 deletions apiclient/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
from typing import Any

import aiohttp
import requests

from apiclient.utils.typing import JsonType
Expand Down Expand Up @@ -62,3 +64,26 @@ def get_status_reason(self) -> str:

def get_requested_url(self) -> str:
return self._response.url


class AioHttpResponse(RequestsResponse):
"""Implementation of the response for a requests.response type."""

def __init__(self, response: aiohttp.ClientResponse, content: bytes):
self._response = response
self._content = content
self._text = ""

def get_status_code(self) -> int:
return self._response.status

def get_raw_data(self) -> str:
if not self._text:
self._text = self._content.decode(self._response.get_encoding(), errors="strict")
return self._text

def get_json(self) -> JsonType:
return json.loads(self._text)

def get_requested_url(self) -> str:
return str(self._response.url)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool:pytest]
addopts = --cov=apiclient/ --cov-fail-under=100 --cov-report html
addopts = --asyncio-mode=auto --cov=apiclient/ --cov-fail-under=100 --cov-report html
env =
ENDPOINT_BASE_URL=http://environment.com

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import setuptools

# Pinning tenacity as the api has changed slightly which breaks all tests.
application_dependencies = ["requests>=2.16", "tenacity>=5.1.0"]
application_dependencies = ["requests>=2.16", "aiohttp>=3.8", "tenacity>=5.1.0"]
prod_dependencies = []
test_dependencies = ["pytest", "pytest-env", "pytest-cov", "vcrpy", "requests-mock"]
test_dependencies = ["pytest", "pytest-env", "pytest-cov", "vcrpy", "requests-mock", "pytest-asyncio", "aioresponses"]
lint_dependencies = ["flake8", "flake8-docstrings", "black", "isort"]
docs_dependencies = []
dev_dependencies = test_dependencies + lint_dependencies + docs_dependencies + ["ipdb"]
Expand Down
Loading