Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): send retry count header #528

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 54 additions & 47 deletions src/modern_treasury/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,7 @@ def _make_status_error(
) -> _exceptions.APIStatusError:
raise NotImplementedError()

def _remaining_retries(
self,
remaining_retries: Optional[int],
options: FinalRequestOptions,
) -> int:
return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)

def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
custom_headers = options.headers or {}
headers_dict = _merge_mappings(self.default_headers, custom_headers)
self._validate_headers(headers_dict, custom_headers)
Expand All @@ -420,6 +413,8 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

headers.setdefault("x-stainless-retry-count", str(retries_taken))

return headers

def _prepare_url(self, url: str) -> URL:
Expand All @@ -441,6 +436,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
def _build_request(
self,
options: FinalRequestOptions,
*,
retries_taken: int = 0,
) -> httpx.Request:
if log.isEnabledFor(logging.DEBUG):
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
Expand All @@ -456,7 +453,7 @@ def _build_request(
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")

headers = self._build_headers(options)
headers = self._build_headers(options, retries_taken=retries_taken)
params = _merge_mappings(self.default_query, options.params)
content_type = headers.get("Content-Type")
files = options.files
Expand Down Expand Up @@ -939,20 +936,25 @@ def request(
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
if remaining_retries is not None:
retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
else:
retries_taken = 0

return self._request(
cast_to=cast_to,
options=options,
stream=stream,
stream_cls=stream_cls,
remaining_retries=remaining_retries,
retries_taken=retries_taken,
)

def _request(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
remaining_retries: int | None,
retries_taken: int,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
Expand All @@ -964,8 +966,8 @@ def _request(
cast_to = self._maybe_override_cast_to(cast_to, options)
options = self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
request = self._build_request(options, retries_taken=retries_taken)
self._prepare_request(request)

kwargs: HttpxSendArgs = {}
Expand All @@ -983,11 +985,11 @@ def _request(
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)

if retries > 0:
if remaining_retries > 0:
return self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -998,11 +1000,11 @@ def _request(
except Exception as err:
log.debug("Encountered Exception", exc_info=True)

if retries > 0:
if remaining_retries > 0:
return self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -1025,13 +1027,13 @@ def _request(
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)

if retries > 0 and self._should_retry(err.response):
if remaining_retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
input_options,
cast_to,
retries,
err.response.headers,
retries_taken=retries_taken,
response_headers=err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
Expand All @@ -1050,26 +1052,26 @@ def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=options.get_max_retries(self.max_retries) - retries,
retries_taken=retries_taken,
)

def _retry_request(
self,
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: int,
response_headers: httpx.Headers | None,
*,
retries_taken: int,
response_headers: httpx.Headers | None,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
remaining = remaining_retries - 1
if remaining == 1:
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining)
log.debug("%i retries left", remaining_retries)

timeout = self._calculate_retry_timeout(remaining, options, response_headers)
timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)

# In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
Expand All @@ -1079,7 +1081,7 @@ def _retry_request(
return self._request(
options=options,
cast_to=cast_to,
remaining_retries=remaining,
retries_taken=retries_taken + 1,
stream=stream,
stream_cls=stream_cls,
)
Expand Down Expand Up @@ -1511,12 +1513,17 @@ async def request(
stream_cls: type[_AsyncStreamT] | None = None,
remaining_retries: Optional[int] = None,
) -> ResponseT | _AsyncStreamT:
if remaining_retries is not None:
retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
else:
retries_taken = 0

return await self._request(
cast_to=cast_to,
options=options,
stream=stream,
stream_cls=stream_cls,
remaining_retries=remaining_retries,
retries_taken=retries_taken,
)

async def _request(
Expand All @@ -1526,7 +1533,7 @@ async def _request(
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
remaining_retries: int | None,
retries_taken: int,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
Expand All @@ -1541,8 +1548,8 @@ async def _request(
cast_to = self._maybe_override_cast_to(cast_to, options)
options = await self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
request = self._build_request(options, retries_taken=retries_taken)
await self._prepare_request(request)

kwargs: HttpxSendArgs = {}
Expand All @@ -1558,11 +1565,11 @@ async def _request(
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)

if retries > 0:
if remaining_retries > 0:
return await self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -1573,11 +1580,11 @@ async def _request(
except Exception as err:
log.debug("Encountered Exception", exc_info=True)

if retries > 0:
if retries_taken > 0:
return await self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -1595,13 +1602,13 @@ async def _request(
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)

if retries > 0 and self._should_retry(err.response):
if remaining_retries > 0 and self._should_retry(err.response):
await err.response.aclose()
return await self._retry_request(
input_options,
cast_to,
retries,
err.response.headers,
retries_taken=retries_taken,
response_headers=err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
Expand All @@ -1620,34 +1627,34 @@ async def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=options.get_max_retries(self.max_retries) - retries,
retries_taken=retries_taken,
)

async def _retry_request(
self,
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: int,
response_headers: httpx.Headers | None,
*,
retries_taken: int,
response_headers: httpx.Headers | None,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
) -> ResponseT | _AsyncStreamT:
remaining = remaining_retries - 1
if remaining == 1:
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining)
log.debug("%i retries left", remaining_retries)

timeout = self._calculate_retry_timeout(remaining, options, response_headers)
timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)

await anyio.sleep(timeout)

return await self._request(
options=options,
cast_to=cast_to,
remaining_retries=remaining,
retries_taken=retries_taken + 1,
stream=stream,
stream_cls=stream_cls,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
response = client.counterparties.with_raw_response.create(name="name")

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("modern_treasury._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
Expand All @@ -1002,6 +1003,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:

with client.counterparties.with_streaming_response.create(name="name") as response:
assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success


class TestAsyncModernTreasury:
Expand Down Expand Up @@ -1942,6 +1944,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
response = await client.counterparties.with_raw_response.create(name="name")

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("modern_treasury._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
Expand All @@ -1965,3 +1968,4 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:

async with client.counterparties.with_streaming_response.create(name="name") as response:
assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success