diff --git a/docs/sphinx/installation.rst b/docs/sphinx/installation.rst index 120cac7..8955bd5 100644 --- a/docs/sphinx/installation.rst +++ b/docs/sphinx/installation.rst @@ -11,4 +11,4 @@ Install the ``requests`` package to use :class:`elastic_transport.RequestsHttpNo Install the ``aiohttp`` package to use :class:`elastic_transport.AiohttpHttpNode`. -Install the ``httpx`` package to use :class:`elastic_transport.HttpxAsyncHttpNode`. +Install the ``httpx`` package to use :class:`elastic_transport.HttpxHttpNode` or :class:`elastic_transport.HttpxAsyncHttpNode`. diff --git a/docs/sphinx/nodes.rst b/docs/sphinx/nodes.rst index d1b50e6..334b2c2 100644 --- a/docs/sphinx/nodes.rst +++ b/docs/sphinx/nodes.rst @@ -22,6 +22,9 @@ Node classes .. autoclass:: AiohttpHttpNode :members: +.. autoclass:: HttpxHttpNode + :members: + .. autoclass:: HttpxAsyncHttpNode :members: diff --git a/elastic_transport/__init__.py b/elastic_transport/__init__.py index d65892a..051d435 100644 --- a/elastic_transport/__init__.py +++ b/elastic_transport/__init__.py @@ -37,6 +37,7 @@ BaseAsyncNode, BaseNode, HttpxAsyncHttpNode, + HttpxHttpNode, RequestsHttpNode, Urllib3HttpNode, ) @@ -74,6 +75,7 @@ "HeadApiResponse", "HttpHeaders", "HttpxAsyncHttpNode", + "HttpxHttpNode", "JsonSerializer", "ListApiResponse", "NdjsonSerializer", diff --git a/elastic_transport/_node/__init__.py b/elastic_transport/_node/__init__.py index e5af033..e718779 100644 --- a/elastic_transport/_node/__init__.py +++ b/elastic_transport/_node/__init__.py @@ -18,7 +18,7 @@ from ._base import BaseNode, NodeApiResponse from ._base_async import BaseAsyncNode from ._http_aiohttp import AiohttpHttpNode -from ._http_httpx import HttpxAsyncHttpNode +from ._http_httpx import HttpxAsyncHttpNode, HttpxHttpNode from ._http_requests import RequestsHttpNode from ._http_urllib3 import Urllib3HttpNode @@ -29,5 +29,6 @@ "NodeApiResponse", "RequestsHttpNode", "Urllib3HttpNode", + "HttpxHttpNode", "HttpxAsyncHttpNode", ] diff --git a/elastic_transport/_node/_http_httpx.py b/elastic_transport/_node/_http_httpx.py index 04ceb60..642d0f9 100644 --- a/elastic_transport/_node/_http_httpx.py +++ b/elastic_transport/_node/_http_httpx.py @@ -30,6 +30,7 @@ BUILTIN_EXCEPTIONS, DEFAULT_CA_CERTS, RERAISE_EXCEPTIONS, + BaseNode, NodeApiResponse, ssl_context_from_node_config, ) @@ -45,6 +46,161 @@ _HTTPX_META_VERSION = "" +class HttpxHttpNode(BaseNode): + _CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION) + + def __init__(self, config: NodeConfig): + if not _HTTPX_AVAILABLE: # pragma: nocover + raise ValueError("You must have 'httpx' installed to use HttpxNode") + super().__init__(config) + + if config.ssl_assert_fingerprint: + raise ValueError( + "httpx does not support certificate pinning. https://github.com/encode/httpx/issues/761" + ) + + ssl_context: Union[ssl.SSLContext, Literal[False]] = False + if config.scheme == "https": + if config.ssl_context is not None: + ssl_context = ssl_context_from_node_config(config) + else: + ssl_context = ssl_context_from_node_config(config) + + ca_certs = ( + DEFAULT_CA_CERTS if config.ca_certs is None else config.ca_certs + ) + if config.verify_certs: + if not ca_certs: + raise ValueError( + "Root certificates are missing for certificate " + "validation. Either pass them in using the ca_certs parameter or " + "install certifi to use it automatically." + ) + else: + if config.ssl_show_warn: + warnings.warn( + f"Connecting to {self.base_url!r} using TLS with verify_certs=False is insecure", + stacklevel=warn_stacklevel(), + category=SecurityWarning, + ) + + if ca_certs is not None: + if os.path.isfile(ca_certs): + ssl_context.load_verify_locations(cafile=ca_certs) + elif os.path.isdir(ca_certs): + ssl_context.load_verify_locations(capath=ca_certs) + else: + raise ValueError("ca_certs parameter is not a path") + + # Use client_cert and client_key variables for SSL certificate configuration. + if config.client_cert and not os.path.isfile(config.client_cert): + raise ValueError("client_cert is not a path to a file") + if config.client_key and not os.path.isfile(config.client_key): + raise ValueError("client_key is not a path to a file") + if config.client_cert and config.client_key: + ssl_context.load_cert_chain(config.client_cert, config.client_key) + elif config.client_cert: + ssl_context.load_cert_chain(config.client_cert) + + self.client = httpx.Client( + base_url=f"{config.scheme}://{config.host}:{config.port}", + limits=httpx.Limits(max_connections=config.connections_per_node), + verify=ssl_context or False, + timeout=config.request_timeout, + ) + + def perform_request( + self, + method: str, + target: str, + body: Optional[bytes] = None, + headers: Optional[HttpHeaders] = None, + request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, + ) -> NodeApiResponse: + resolved_headers = self._headers.copy() + if headers: + resolved_headers.update(headers) + + if body: + if self._http_compress: + resolved_body = gzip.compress(body) + resolved_headers["content-encoding"] = "gzip" + else: + resolved_body = body + else: + resolved_body = None + + try: + start = time.perf_counter() + if request_timeout is DEFAULT: + resp = self.client.request( + method, + target, + content=resolved_body, + headers=dict(resolved_headers), + ) + else: + resp = self.client.request( + method, + target, + content=resolved_body, + headers=dict(resolved_headers), + timeout=request_timeout, + ) + response_body = resp.read() + duration = time.perf_counter() - start + except RERAISE_EXCEPTIONS + BUILTIN_EXCEPTIONS: + raise + except Exception as e: + err: Exception + if isinstance(e, (TimeoutError, httpx.TimeoutException)): + err = ConnectionTimeout( + "Connection timed out during request", errors=(e,) + ) + elif isinstance(e, ssl.SSLError): + err = TlsError(str(e), errors=(e,)) + # Detect SSL errors for httpx v0.28.0+ + # Needed until https://github.com/encode/httpx/issues/3350 is fixed + elif isinstance(e, httpx.ConnectError) and e.__cause__: + context = e.__cause__.__context__ + if isinstance(context, ssl.SSLError): + err = TlsError(str(context), errors=(e,)) + else: + err = ConnectionError(str(e), errors=(e,)) + else: + err = ConnectionError(str(e), errors=(e,)) + self._log_request( + method=method, + target=target, + headers=resolved_headers, + body=body, + exception=err, + ) + raise err from None + + meta = ApiResponseMeta( + resp.status_code, + resp.http_version, + HttpHeaders(resp.headers), + duration, + self.config, + ) + + self._log_request( + method=method, + target=target, + headers=resolved_headers, + body=body, + meta=meta, + response=response_body, + ) + + return NodeApiResponse(meta, response_body) + + def close(self) -> None: + self.client.close() + + class HttpxAsyncHttpNode(BaseAsyncNode): _CLIENT_META_HTTP_CLIENT = ("hx", _HTTPX_META_VERSION) diff --git a/elastic_transport/_transport.py b/elastic_transport/_transport.py index 32fb9e0..c1bf093 100644 --- a/elastic_transport/_transport.py +++ b/elastic_transport/_transport.py @@ -56,6 +56,7 @@ AiohttpHttpNode, BaseNode, HttpxAsyncHttpNode, + HttpxHttpNode, RequestsHttpNode, Urllib3HttpNode, ) @@ -70,6 +71,7 @@ "urllib3": Urllib3HttpNode, "requests": RequestsHttpNode, "aiohttp": AiohttpHttpNode, + "httpx": HttpxHttpNode, "httpxasync": HttpxAsyncHttpNode, } # These are HTTP status errors that shouldn't be considered diff --git a/tests/async_/test_async_transport.py b/tests/async_/test_async_transport.py index 24a869c..fdfcacb 100644 --- a/tests/async_/test_async_transport.py +++ b/tests/async_/test_async_transport.py @@ -315,7 +315,7 @@ async def test_node_class_as_string(): AsyncTransport([NodeConfig("http", "localhost", 80)], node_class="huh?") assert str(e.value) == ( "Unknown option for node_class: 'huh?'. " - "Available options are: 'aiohttp', 'httpxasync', 'requests', 'urllib3'" + "Available options are: 'aiohttp', 'httpx', 'httpxasync', 'requests', 'urllib3'" ) diff --git a/tests/node/test_base.py b/tests/node/test_base.py index 8d2d667..e9c59cf 100644 --- a/tests/node/test_base.py +++ b/tests/node/test_base.py @@ -20,6 +20,7 @@ from elastic_transport import ( AiohttpHttpNode, HttpxAsyncHttpNode, + HttpxHttpNode, NodeConfig, RequestsHttpNode, Urllib3HttpNode, @@ -28,7 +29,14 @@ @pytest.mark.parametrize( - "node_cls", [Urllib3HttpNode, RequestsHttpNode, AiohttpHttpNode, HttpxAsyncHttpNode] + "node_cls", + [ + Urllib3HttpNode, + RequestsHttpNode, + AiohttpHttpNode, + HttpxHttpNode, + HttpxAsyncHttpNode, + ], ) def test_unknown_parameter(node_cls): with pytest.raises(TypeError): diff --git a/tests/node/test_http_httpx.py b/tests/node/test_http_httpx.py index ce6e7f4..341cc1d 100644 --- a/tests/node/test_http_httpx.py +++ b/tests/node/test_http_httpx.py @@ -22,11 +22,129 @@ import pytest import respx -from elastic_transport import HttpxAsyncHttpNode, NodeConfig +from elastic_transport import HttpxAsyncHttpNode, HttpxHttpNode, NodeConfig from elastic_transport._node._base import DEFAULT_USER_AGENT -def create_node(node_config: NodeConfig): +def create_sync_node(node_config: NodeConfig): + return HttpxHttpNode(node_config) + + +class TestHttpxNodeCreation: + def test_ssl_context(self): + ssl_context = ssl.create_default_context() + with warnings.catch_warnings(record=True) as w: + node = create_sync_node( + NodeConfig( + scheme="https", + host="localhost", + port=80, + ssl_context=ssl_context, + ) + ) + assert node.client._transport._pool._ssl_context is ssl_context + assert len(w) == 0 + + def test_uses_https_if_verify_certs_is_off(self): + with warnings.catch_warnings(record=True) as w: + _ = create_sync_node( + NodeConfig("https", "localhost", 443, verify_certs=False) + ) + assert ( + str(w[0].message) + == "Connecting to 'https://localhost:443' using TLS with verify_certs=False is insecure" + ) + + def test_no_warn_when_uses_https_if_verify_certs_is_off(self): + with warnings.catch_warnings(record=True) as w: + _ = create_sync_node( + NodeConfig( + "https", + "localhost", + 443, + verify_certs=False, + ssl_show_warn=False, + ) + ) + assert 0 == len(w) + + def test_ca_certs_with_verify_ssl_false_raises_error(self): + with pytest.raises(ValueError) as exc: + create_sync_node( + NodeConfig( + "https", + "localhost", + 443, + ca_certs="/ca/certs", + verify_certs=False, + ) + ) + assert ( + str(exc.value) == "You cannot use 'ca_certs' when 'verify_certs=False'" + ) + + +class TestHttpxNode: + @respx.mock + def test_simple_request(self): + node = create_sync_node(NodeConfig(scheme="http", host="localhost", port=80)) + respx.get("http://localhost/index") + node.perform_request("GET", "/index", b"hello world", headers={"key": "value"}) + request = respx.calls.last.request + assert request.content == b"hello world" + assert { + "key": "value", + "connection": "keep-alive", + "user-agent": DEFAULT_USER_AGENT, + }.items() <= request.headers.items() + + @respx.mock + def test_compression(self): + node = create_sync_node( + NodeConfig(scheme="http", host="localhost", port=80, http_compress=True) + ) + respx.get("http://localhost/index") + node.perform_request("GET", "/index", b"hello world") + request = respx.calls.last.request + assert gzip.decompress(request.content) == b"hello world" + assert {"content-encoding": "gzip"}.items() <= request.headers.items() + + @respx.mock + def test_default_timeout(self): + node = create_sync_node( + NodeConfig(scheme="http", host="localhost", port=80, request_timeout=10) + ) + respx.get("http://localhost/index") + node.perform_request("GET", "/index", b"hello world") + request = respx.calls.last.request + assert request.extensions["timeout"]["connect"] == 10 + + @respx.mock + def test_overwritten_timeout(self): + node = create_sync_node( + NodeConfig(scheme="http", host="localhost", port=80, request_timeout=10) + ) + respx.get("http://localhost/index") + node.perform_request("GET", "/index", b"hello world", request_timeout=15) + request = respx.calls.last.request + assert request.extensions["timeout"]["connect"] == 15 + + @respx.mock + def test_merge_headers(self): + node = create_sync_node( + NodeConfig("http", "localhost", 80, headers={"h1": "v1", "h2": "v2"}) + ) + respx.get("http://localhost/index") + node.perform_request( + "GET", "/index", b"hello world", headers={"h2": "v2p", "h3": "v3"} + ) + request = respx.calls.last.request + assert request.headers["h1"] == "v1" + assert request.headers["h2"] == "v2p" + assert request.headers["h3"] == "v3" + + +def create_async_node(node_config: NodeConfig): return HttpxAsyncHttpNode(node_config) @@ -34,7 +152,7 @@ class TestHttpxAsyncNodeCreation: def test_ssl_context(self): ssl_context = ssl.create_default_context() with warnings.catch_warnings(record=True) as w: - node = create_node( + node = create_async_node( NodeConfig( scheme="https", host="localhost", @@ -47,7 +165,9 @@ def test_ssl_context(self): def test_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: - _ = create_node(NodeConfig("https", "localhost", 443, verify_certs=False)) + _ = create_async_node( + NodeConfig("https", "localhost", 443, verify_certs=False) + ) assert ( str(w[0].message) == "Connecting to 'https://localhost:443' using TLS with verify_certs=False is insecure" @@ -55,7 +175,7 @@ def test_uses_https_if_verify_certs_is_off(self): def test_no_warn_when_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: - _ = create_node( + _ = create_async_node( NodeConfig( "https", "localhost", @@ -68,7 +188,7 @@ def test_no_warn_when_uses_https_if_verify_certs_is_off(self): def test_ca_certs_with_verify_ssl_false_raises_error(self): with pytest.raises(ValueError) as exc: - create_node( + create_async_node( NodeConfig( "https", "localhost", @@ -86,7 +206,7 @@ def test_ca_certs_with_verify_ssl_false_raises_error(self): class TestHttpxAsyncNode: @respx.mock async def test_simple_request(self): - node = create_node(NodeConfig(scheme="http", host="localhost", port=80)) + node = create_async_node(NodeConfig(scheme="http", host="localhost", port=80)) respx.get("http://localhost/index") await node.perform_request( "GET", "/index", b"hello world", headers={"key": "value"} @@ -101,7 +221,7 @@ async def test_simple_request(self): @respx.mock async def test_compression(self): - node = create_node( + node = create_async_node( NodeConfig(scheme="http", host="localhost", port=80, http_compress=True) ) respx.get("http://localhost/index") @@ -112,7 +232,7 @@ async def test_compression(self): @respx.mock async def test_default_timeout(self): - node = create_node( + node = create_async_node( NodeConfig(scheme="http", host="localhost", port=80, request_timeout=10) ) respx.get("http://localhost/index") @@ -122,7 +242,7 @@ async def test_default_timeout(self): @respx.mock async def test_overwritten_timeout(self): - node = create_node( + node = create_async_node( NodeConfig(scheme="http", host="localhost", port=80, request_timeout=10) ) respx.get("http://localhost/index") @@ -132,7 +252,7 @@ async def test_overwritten_timeout(self): @respx.mock async def test_merge_headers(self): - node = create_node( + node = create_async_node( NodeConfig("http", "localhost", 80, headers={"h1": "v1", "h2": "v2"}) ) respx.get("http://localhost/index") @@ -145,9 +265,10 @@ async def test_merge_headers(self): assert request.headers["h3"] == "v3" -def test_ssl_assert_fingerprint(cert_fingerprint, httpbin_secure): +@pytest.mark.parametrize("node_class", [HttpxHttpNode, HttpxAsyncHttpNode]) +def test_ssl_assert_fingerprint(node_class, cert_fingerprint, httpbin_secure): with pytest.raises(ValueError, match="httpx does not support certificate pinning"): - HttpxAsyncHttpNode( + node_class( NodeConfig( scheme="https", host=httpbin_secure.host, diff --git a/tests/node/test_tls_versions.py b/tests/node/test_tls_versions.py index e687d9f..2e7e553 100644 --- a/tests/node/test_tls_versions.py +++ b/tests/node/test_tls_versions.py @@ -24,6 +24,7 @@ from elastic_transport import ( AiohttpHttpNode, HttpxAsyncHttpNode, + HttpxHttpNode, NodeConfig, RequestsHttpNode, TlsError, @@ -38,7 +39,13 @@ node_classes = pytest.mark.parametrize( "node_class", - [AiohttpHttpNode, Urllib3HttpNode, RequestsHttpNode, HttpxAsyncHttpNode], + [ + AiohttpHttpNode, + Urllib3HttpNode, + RequestsHttpNode, + HttpxHttpNode, + HttpxAsyncHttpNode, + ], ) supported_version_params = [ diff --git a/tests/test_transport.py b/tests/test_transport.py index 08d8e04..dd2754d 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -328,7 +328,7 @@ def test_node_class_as_string(): Transport([NodeConfig("http", "localhost", 80)], node_class="huh?") assert str(e.value) == ( "Unknown option for node_class: 'huh?'. " - "Available options are: 'aiohttp', 'httpxasync', 'requests', 'urllib3'" + "Available options are: 'aiohttp', 'httpx', 'httpxasync', 'requests', 'urllib3'" ) @@ -364,9 +364,9 @@ def test_transport_client_meta_node_class(node_class): assert ( t._transport_client_meta[3] == t.node_pool.node_class._CLIENT_META_HTTP_CLIENT ) - assert t._transport_client_meta[3][0] in ("ur", "rq") + assert t._transport_client_meta[3][0] in ("ur", "rq", "hx") assert re.match( - r"^et=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?,(?:ur|rq)=[0-9.]+p?$", + r"^et=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?,(?:ur|rq|hx)=[0-9.]+p?$", ",".join(f"{k}={v}" for k, v in t._transport_client_meta), )