diff --git a/httpx/_client.py b/httpx/_client.py index 5801abe4d0..76325c147d 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -45,12 +45,7 @@ TimeoutTypes, ) from ._urls import URL, QueryParams -from ._utils import ( - URLPattern, - get_environment_proxies, - is_https_redirect, - same_origin, -) +from ._utils import URLPattern, get_environment_proxies if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -63,6 +58,38 @@ U = typing.TypeVar("U", bound="AsyncClient") +def _is_https_redirect(url: URL, location: URL) -> bool: + """ + Return 'True' if 'location' is a HTTPS upgrade of 'url' + """ + if url.host != location.host: + return False + + return ( + url.scheme == "http" + and _port_or_default(url) == 80 + and location.scheme == "https" + and _port_or_default(location) == 443 + ) + + +def _port_or_default(url: URL) -> int | None: + if url.port is not None: + return url.port + return {"http": 80, "https": 443}.get(url.scheme) + + +def _same_origin(url: URL, other: URL) -> bool: + """ + Return 'True' if the given URLs share the same origin. + """ + return ( + url.scheme == other.scheme + and url.host == other.host + and _port_or_default(url) == _port_or_default(other) + ) + + class UseClientDefault: """ For some parameters such as `auth=...` and `timeout=...` we need to be able @@ -521,8 +548,8 @@ def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers: """ headers = Headers(request.headers) - if not same_origin(url, request.url): - if not is_https_redirect(request.url, url): + if not _same_origin(url, request.url): + if not _is_https_redirect(request.url, url): # Strip Authorization headers when responses are redirected # away from the origin. (Except for direct HTTP to HTTPS redirects.) headers.pop("Authorization", None) diff --git a/httpx/_utils.py b/httpx/_utils.py index 9a1ed54749..7fe827da4d 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -27,38 +27,6 @@ def primitive_value_to_str(value: PrimitiveData) -> str: return str(value) -def port_or_default(url: URL) -> int | None: - if url.port is not None: - return url.port - return {"http": 80, "https": 443}.get(url.scheme) - - -def same_origin(url: URL, other: URL) -> bool: - """ - Return 'True' if the given URLs share the same origin. - """ - return ( - url.scheme == other.scheme - and url.host == other.host - and port_or_default(url) == port_or_default(other) - ) - - -def is_https_redirect(url: URL, location: URL) -> bool: - """ - Return 'True' if 'location' is a HTTPS upgrade of 'url' - """ - if url.host != location.host: - return False - - return ( - url.scheme == "http" - and port_or_default(url) == 80 - and location.scheme == "https" - and port_or_default(location) == 443 - ) - - def get_environment_proxies() -> dict[str, str | None]: """Gets proxy information from the environment""" diff --git a/tests/client/test_headers.py b/tests/client/test_headers.py index b8d2976700..47f5a4d731 100755 --- a/tests/client/test_headers.py +++ b/tests/client/test_headers.py @@ -235,3 +235,59 @@ def test_host_with_non_default_port_in_url(): def test_request_auto_headers(): request = httpx.Request("GET", "https://www.example.org/") assert "host" in request.headers + + +def test_same_origin(): + origin = httpx.URL("https://example.com") + request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443") + + client = httpx.Client() + headers = client._redirect_headers(request, origin, "GET") + + assert headers["Host"] == request.url.netloc.decode("ascii") + + +def test_not_same_origin(): + origin = httpx.URL("https://example.com") + request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80") + + client = httpx.Client() + headers = client._redirect_headers(request, origin, "GET") + + assert headers["Host"] == origin.netloc.decode("ascii") + + +def test_is_https_redirect(): + url = httpx.URL("https://example.com") + request = httpx.Request( + "GET", "http://example.com", headers={"Authorization": "empty"} + ) + + client = httpx.Client() + headers = client._redirect_headers(request, url, "GET") + + assert "Authorization" in headers + + +def test_is_not_https_redirect(): + url = httpx.URL("https://www.example.com") + request = httpx.Request( + "GET", "http://example.com", headers={"Authorization": "empty"} + ) + + client = httpx.Client() + headers = client._redirect_headers(request, url, "GET") + + assert "Authorization" not in headers + + +def test_is_not_https_redirect_if_not_default_ports(): + url = httpx.URL("https://example.com:1337") + request = httpx.Request( + "GET", "http://example.com:9999", headers={"Authorization": "empty"} + ) + + client = httpx.Client() + headers = client._redirect_headers(request, url, "GET") + + assert "Authorization" not in headers diff --git a/tests/test_utils.py b/tests/test_utils.py index 3e2abdef28..f9c215f65a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,10 +6,7 @@ import pytest import httpx -from httpx._utils import ( - URLPattern, - get_environment_proxies, -) +from httpx._utils import URLPattern, get_environment_proxies @pytest.mark.parametrize( @@ -115,62 +112,6 @@ def test_get_environment_proxies(environment, proxies): assert get_environment_proxies() == proxies -def test_same_origin(): - origin = httpx.URL("https://example.com") - request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443") - - client = httpx.Client() - headers = client._redirect_headers(request, origin, "GET") - - assert headers["Host"] == request.url.netloc.decode("ascii") - - -def test_not_same_origin(): - origin = httpx.URL("https://example.com") - request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80") - - client = httpx.Client() - headers = client._redirect_headers(request, origin, "GET") - - assert headers["Host"] == origin.netloc.decode("ascii") - - -def test_is_https_redirect(): - url = httpx.URL("https://example.com") - request = httpx.Request( - "GET", "http://example.com", headers={"Authorization": "empty"} - ) - - client = httpx.Client() - headers = client._redirect_headers(request, url, "GET") - - assert "Authorization" in headers - - -def test_is_not_https_redirect(): - url = httpx.URL("https://www.example.com") - request = httpx.Request( - "GET", "http://example.com", headers={"Authorization": "empty"} - ) - - client = httpx.Client() - headers = client._redirect_headers(request, url, "GET") - - assert "Authorization" not in headers - - -def test_is_not_https_redirect_if_not_default_ports(): - url = httpx.URL("https://example.com:1337") - request = httpx.Request( - "GET", "http://example.com:9999", headers={"Authorization": "empty"} - ) - - client = httpx.Client() - headers = client._redirect_headers(request, url, "GET") - - assert "Authorization" not in headers - - @pytest.mark.parametrize( ["pattern", "url", "expected"], [