diff --git a/solnlib/splunk_rest_client.py b/solnlib/splunk_rest_client.py index 7e858248..1f44ef58 100644 --- a/solnlib/splunk_rest_client.py +++ b/solnlib/splunk_rest_client.py @@ -71,6 +71,8 @@ def _request_handler(context): 'cert_file': string 'pool_connections', int, 'pool_maxsize', int, + 'max_retries': int, + 'retry_status_codes': list, } :type content: dict """ @@ -102,25 +104,32 @@ def _request_handler(context): else: cert = None - retries = Retry( - total=MAX_REQUEST_RETRIES, - backoff_factor=0.3, - status_forcelist=[500, 502, 503, 504], - allowed_methods=["GET", "POST", "PUT", "DELETE"], - raise_on_status=False, - ) - if context.get("pool_connections", 0): - logging.info("Use HTTP connection pooling") - session = requests.Session() - adapter = requests.adapters.HTTPAdapter( - max_retries=retries, - pool_connections=context.get("pool_connections", 10), - pool_maxsize=context.get("pool_maxsize", 10), + def adapter(): + retries = Retry( + total=context.get("max_retries", MAX_REQUEST_RETRIES), + backoff_factor=0.3, + status_forcelist=context.get("retry_status_codes", [500, 502, 503, 504]), + allowed_methods=["GET", "POST", "PUT", "DELETE"], + raise_on_status=False, ) - session.mount("https://", adapter) - req_func = session.request - else: - req_func = requests.request + + adapter_args = { + "max_retries": retries, + } + + # By default, pool_connections and pool_maxsize are set to 10 in urllib3 + if "pool_connections" in context: + adapter_args["pool_connections"] = context["pool_connections"] + if "pool_maxsize" in context: + adapter_args["pool_maxsize"] = context["pool_maxsize"] + + return requests.adapters.HTTPAdapter(**adapter_args) + + session = requests.Session() + session.mount("http://", adapter()) + session.mount("https://", adapter()) + + req_func = session.request def request(url, message, **kwargs): """ diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 00000000..8b3bf665 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,77 @@ +import json +import socket +from contextlib import closing +from http.server import BaseHTTPRequestHandler, HTTPServer +from threading import Thread + +import pytest + + +@pytest.fixture(scope="session") +def http_mock_server(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + + class Mock: + def __init__(self, host, port): + self.host = host + self.port = port + self.get_func = None + + def get(self, func): + self.get_func = func + return func + + mock = Mock("localhost", port) + + class RequestArg: + def __init__(self): + self.headers = { + "Content-Type": "application/json", + } + self.response_code = 200 + + def send_header(self, key, value): + self.headers[key] = value + + def send_response(self, code): + self.response_code = code + + class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if mock.get_func is None: + self.send_response(404) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"error": "Not Found"}).encode("utf-8")) + return + + request = RequestArg() + response = mock.get_func(request) + + self.send_response(request.response_code) + + for key, value in request.headers.items(): + self.send_header(key, value) + + self.end_headers() + + if isinstance(response, dict): + response = json.dumps(response) + + self.wfile.write(response.encode("utf-8")) + + server_address = ("", mock.port) + httpd = HTTPServer(server_address, Handler) + + thread = Thread(target=httpd.serve_forever) + thread.setDaemon(True) + thread.start() + + yield mock + + httpd.shutdown() + httpd.server_close() + thread.join() diff --git a/tests/unit/test_splunk_rest_client.py b/tests/unit/test_splunk_rest_client.py index de43dbdd..25d9d10d 100644 --- a/tests/unit/test_splunk_rest_client.py +++ b/tests/unit/test_splunk_rest_client.py @@ -17,6 +17,8 @@ from unittest import mock import pytest +from splunklib.binding import HTTPError + from solnlib.splunk_rest_client import MAX_REQUEST_RETRIES from requests.exceptions import ConnectionError @@ -109,3 +111,67 @@ def test_request_retry(http_conn_pool, http_resp, mock_get_splunkd_access_info): http_conn_pool.side_effect = side_effects with pytest.raises(ConnectionError): rest_client.get("test") + + +@pytest.mark.parametrize("error_code", [429, 500, 503]) +def test_request_throttling(http_mock_server, error_code): + @http_mock_server.get + def throttling(request): + """Mock endpoint to simulate request throttling. + + The endpoint will return an error status code for the first 5 + requests, and a 200 status code for subsequent requests. + """ + number = getattr(throttling, "call_count", 0) + throttling.call_count = number + 1 + + if number < 2: + request.send_response(error_code) + request.send_header("Retry-After", "1") + return {"error": f"Error {number}"} + + return {"content": "Success"} + + rest_client = SplunkRestClient( + "msg_name_1", + "session_key", + "_", + scheme="http", + host="localhost", + port=http_mock_server.port, + ) + + resp = rest_client.get("test") + assert resp.status == 200 + assert resp.body.read().decode("utf-8") == '{"content": "Success"}' + + +@pytest.mark.parametrize("error_code", [429, 500, 503]) +def test_request_throttling_exceeded(http_mock_server, error_code): + @http_mock_server.get + def throttling(request): + """Mock endpoint to simulate request throttling. + + The endpoint will always return an error status code. + """ + number = getattr(throttling, "call_count", 0) + throttling.call_count = number + 1 + + request.send_response(error_code) + request.send_header("Retry-After", "1") + return {"error": f"Error {number}"} + + rest_client = SplunkRestClient( + "msg_name_1", + "session_key", + "_", + scheme="http", + host="localhost", + port=http_mock_server.port, + ) + + with pytest.raises(HTTPError) as ex: + rest_client.get("test") + + assert ex.value.status == error_code + assert ex.value.body.decode("utf-8") == '{"error": "Error 5"}'