Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HTTP proxy support reworked #751

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- run: tox -e py37
py38:
docker:
- image: circleci/python:3.8.0rc1
- image: circleci/python:3.8
steps:
# Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway.
- run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc
Expand Down
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Client

.. automodule:: websockets.client

.. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds)
.. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None, **kwds)
:async:

.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bdist_wheel]
python-tag = py36.py37
python-tag = py36.py37.py38

[metadata]
license_file = LICENSE
Expand Down
2 changes: 2 additions & 0 deletions src/websockets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
"InvalidURI",
"NegotiationError",
"Origin",
"parse_proxy_uri",
"parse_uri",
"PayloadTooBig",
"ProtocolError",
"ProxyURI",
"RedirectHandshake",
"SecurityError",
"serve",
Expand Down
137 changes: 125 additions & 12 deletions src/websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import collections.abc
import functools
import logging
import urllib.request
import warnings
from ssl import Purpose, SSLContext, create_default_context
from types import TracebackType
from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast
from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, Union, cast

from .exceptions import (
InvalidHandshake,
Expand All @@ -33,11 +35,13 @@
from .http import USER_AGENT, Headers, HeadersLike, read_response
from .protocol import WebSocketCommonProtocol
from .typing import ExtensionHeader, Origin, Subprotocol
from .uri import WebSocketURI, parse_uri
from .uri import ProxyURI, WebSocketURI, parse_proxy_uri, parse_uri


__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]

USE_SYSTEM_PROXY = object()

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -223,6 +227,70 @@ def process_subprotocol(

return subprotocol

async def proxy_connect(
self,
proxy_uri: ProxyURI,
wsuri: WebSocketURI,
ssl: Optional[Union[SSLContext, bool]] = None,
server_hostname: Optional[str] = None,
) -> None:
"""
Issue a CONNECT request, read a response and upgrade the connection
to TLS, if necessary.

:param proxy_uri: the URI of the HTTP proxy
:param wsuri: the original WebSocket URI to connect to
:param ssl: an optional :class:`~ssl.SSLContext` with
TLS settings for the proxied HTTPS connection; ``None``
for not allowing TLS
:raises ValueError: if the proxy returns an error code

"""
request_headers = Headers()

if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
request_headers["Host"] = wsuri.host
else:
request_headers["Host"] = f"{wsuri.host}:{wsuri.port}"

if proxy_uri.user_info:
request_headers["Proxy-Authorization"] = build_authorization_basic(
*proxy_uri.user_info
)

logger.debug("%s > CONNECT %s HTTP/1.1", self.side, f"{wsuri.host}:{wsuri.port}")
logger.debug("%s > %r", self.side, request_headers)

request = f"CONNECT {wsuri.host}:{wsuri.port} HTTP/1.1\r\n"
request += str(request_headers)

self.transport.write(request.encode())

try:
status_code, reason, headers = await read_response(self.reader)
except asyncio.CancelledError: # pragma: no cover
raise
except Exception as exc:
raise InvalidMessage("did not receive a valid HTTP response") from exc

logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
logger.debug("%s < %r", self.side, headers)

if not 200 <= status_code < 300:
# TODO improve error handling
raise ValueError(f"proxy error: HTTP {status_code} {reason}")

if ssl is not None:
transport = await self.loop.start_tls(
self.transport,
self,
sslcontext=create_default_context(Purpose.SERVER_AUTH) if isinstance(ssl, bool) else ssl,
server_side=False,
server_hostname=server_hostname
)
self.reader = asyncio.StreamReader(limit=self.read_limit // 2, loop=self.loop)
self.connection_made(transport)

async def handshake(
self,
wsuri: WebSocketURI,
Expand Down Expand Up @@ -360,6 +428,12 @@ class Connect:
:class:`~websockets.http.Headers` instance, a
:class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
pairs
* ``proxy_uri`` defines the HTTP proxy for establishing the connection; by
default, :func:`connect` uses proxies configured in the environment or
the system (see :func:`~urllib.request.getproxies` for details); set
``proxy_uri`` to ``None`` to disable this behavior
* ``proxy_ssl`` may be set to a :class:`~ssl.SSLContext` to enforce TLS
settings for connecting to a ``https://`` proxy; it defaults to ``True``

:raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
:raises ~websockets.handshake.InvalidHandshake: if the opening handshake
Expand Down Expand Up @@ -391,8 +465,14 @@ def __init__(
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
proxy_uri: Union[str, object] = USE_SYSTEM_PROXY,
proxy_ssl: Optional[Union[SSLContext, bool]] = None,
**kwargs: Any,
) -> None:
conn_host: Optional[str]
conn_port: Optional[int]
conn_ssl: Optional[Union[SSLContext, bool]]

# Backwards compatibility: close_timeout used to be called timeout.
if timeout is None:
timeout = 10
Expand Down Expand Up @@ -423,6 +503,35 @@ def __init__(
"use a wss:// URI to enable TLS"
)

if proxy_uri is USE_SYSTEM_PROXY:
proxies = urllib.request.getproxies()
if urllib.request.proxy_bypass(f"{wsuri.host}:{wsuri.port}"):
proxy_uri = None
else:
# RFC 6455 recommends to prefer the proxy configured for HTTPS
# connections over the proxy configured for HTTP connections.
proxy_uri = proxies.get("https")
if proxy_uri is None and not wsuri.secure:
proxy_uri = proxies.get("http")

if proxy_uri is not None:
proxy_uri = parse_proxy_uri(proxy_uri)
if proxy_uri.secure:
if proxy_ssl is None:
proxy_ssl = True
elif proxy_ssl is not None:
raise ValueError(
"connect() received a TLS/SSL context for an HTTP proxy; "
"use an HTTPS proxy to enable TLS"
)
conn_host, conn_port, conn_ssl = proxy_uri.host, proxy_uri.port, proxy_ssl
else:
conn_host, conn_port, conn_ssl = wsuri.host, wsuri.port, kwargs.get("ssl")

self._ssl = kwargs.pop("ssl", None)
if proxy_uri is not None:
self._server_hostname = kwargs.pop("server_hostname", None)

if compression == "deflate":
if extensions is None:
extensions = []
Expand Down Expand Up @@ -457,26 +566,23 @@ def __init__(
)

if path is None:
host: Optional[str]
port: Optional[int]
if kwargs.get("sock") is None:
host, port = wsuri.host, wsuri.port
else:
if kwargs.get("sock") is not None:
# If sock is given, host and port shouldn't be specified.
host, port = None, None
conn_host, conn_port = None, None
# If host and port are given, override values from the URI.
host = kwargs.pop("host", host)
port = kwargs.pop("port", port)
conn_host = kwargs.pop("host", conn_host)
conn_port = kwargs.pop("port", conn_port)
create_connection = functools.partial(
loop.create_connection, factory, host, port, **kwargs
loop.create_connection, factory, conn_host, conn_port, ssl=conn_ssl, **kwargs
)
else:
create_connection = functools.partial(
loop.create_unix_connection, factory, path, **kwargs
loop.create_unix_connection, factory, path, ssl=conn_ssl, **kwargs
)

# This is a coroutine function.
self._create_connection = create_connection
self._proxy_uri = proxy_uri
self._wsuri = wsuri

def handle_redirect(self, uri: str) -> None:
Expand Down Expand Up @@ -541,6 +647,13 @@ async def __await_impl__(self) -> WebSocketClientProtocol:

try:
try:
if self._proxy_uri is not None:
await protocol.proxy_connect(
self._proxy_uri,
self._wsuri,
self._ssl,
self._server_hostname,
)
await protocol.handshake(
self._wsuri,
origin=protocol.origin,
Expand Down
62 changes: 61 additions & 1 deletion src/websockets/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from .exceptions import InvalidURI


__all__ = ["parse_uri", "WebSocketURI"]
__all__ = [
"parse_uri", "WebSocketURI",
"parse_proxy_uri", "ProxyURI",
]


# Consider converting to a dataclass when dropping support for Python < 3.7.
Expand Down Expand Up @@ -79,3 +82,60 @@ def parse_uri(uri: str) -> WebSocketURI:
raise InvalidURI(uri)
user_info = (parsed.username, parsed.password)
return WebSocketURI(secure, host, port, resource_name, user_info)

class ProxyURI(NamedTuple):
"""
Proxy URI.

:param bool secure: tells whether to connect to the proxy with TLS
:param str host: lower-case host
:param int port: port, always set even if it's the default
:param str user_info: ``(username, password)`` tuple when the URI contains
`User Information`_, else ``None``.

.. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1
"""

secure: bool
host: str
port: int
user_info: Optional[Tuple[str, str]]

# Work around https://bugs.python.org/issue19931

ProxyURI.secure.__doc__ = ""
ProxyURI.host.__doc__ = ""
ProxyURI.port.__doc__ = ""
ProxyURI.user_info.__doc__ = ""

def parse_proxy_uri(uri: str) -> ProxyURI:
"""
Parse and validate an HTTP proxy URI.

:raises ValueError: if ``uri`` isn't a valid HTTP proxy URI.

"""
parsed = urllib.parse.urlparse(uri)
try:
assert parsed.scheme in ['http', 'https']
assert parsed.hostname is not None
assert parsed.path == '' or parsed.path == '/'
assert parsed.params == ''
assert parsed.query == ''
assert parsed.fragment == ''
except AssertionError as exc:
raise InvalidURI(uri) from exc

secure = parsed.scheme == 'https'
host = parsed.hostname
port = parsed.port or (443 if secure else 80)
user_info = None
if parsed.username is not None:
# urllib.parse.urlparse accepts URLs with a username but without a
# password. This doesn't make sense for HTTP Basic Auth credentials.
if parsed.password is None:
raise InvalidURI(uri)
user_info = (parsed.username, parsed.password)
return ProxyURI(secure, host, port, user_info)


27 changes: 25 additions & 2 deletions tests/test_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,38 @@
"ws://user@localhost/",
]

VALID_PROXY_URIS = [
("http://localhost", (False, "localhost", 80, None)),
("http://localhost/", (False, "localhost", 80, None)),
("https://localhost", (True, "localhost", 443, None)),
("http://user:pass@localhost", (False, "localhost", 80, ("user", "pass"))),
]

INVALID_PROXY_URIS = [
"http://localhost/path",
"ws://localhost/",
"wss://localhost/",
]

class URITests(unittest.TestCase):
def test_success(self):
def test_parse_uri_success(self):
for uri, parsed in VALID_URIS:
with self.subTest(uri=uri):
self.assertEqual(parse_uri(uri), parsed)

def test_error(self):
def test_parse_uri_error(self):
for uri in INVALID_URIS:
with self.subTest(uri=uri):
with self.assertRaises(InvalidURI):
parse_uri(uri)

def test_parse_proxy_uri_success(self):
for uri, parsed in VALID_PROXY_URIS:
with self.subTest(uri=uri):
self.assertEqual(parse_proxy_uri(uri), parsed)

def test_parse_proxy_uri_error(self):
for uri in INVALID_PROXY_URIS:
with self.subTest(uri=uri):
with self.assertRaises(InvalidURI):
parse_proxy_uri(uri)