diff --git a/.circleci/config.yml b/.circleci/config.yml
index 0877c161a..68d02416d 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -31,7 +31,7 @@ jobs:
       - run: tox -e py37
   py38:
     docker:
-      - image: circleci/python:3.8.0rc1
+      - image: circleci/python:3.8
     steps:
       # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway.
       - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc
diff --git a/docs/api.rst b/docs/api.rst
index d265a91c2..7210d7014 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -55,7 +55,7 @@ Client
 
 .. automodule:: websockets.client
 
-    .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds)
+    .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None, **kwds)
         :async:
 
     .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds)
diff --git a/setup.cfg b/setup.cfg
index c306b2d4f..72f9f27af 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [bdist_wheel]
-python-tag = py36.py37
+python-tag = py36.py37.py38
 
 [metadata]
 license_file = LICENSE
diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py
index ea1d829a3..a14327ab9 100644
--- a/src/websockets/__init__.py
+++ b/src/websockets/__init__.py
@@ -36,9 +36,11 @@
     "InvalidURI",
     "NegotiationError",
     "Origin",
+    "parse_proxy_uri",
     "parse_uri",
     "PayloadTooBig",
     "ProtocolError",
+    "ProxyURI",
     "RedirectHandshake",
     "SecurityError",
     "serve",
diff --git a/src/websockets/client.py b/src/websockets/client.py
index be055310d..0bad41927 100644
--- a/src/websockets/client.py
+++ b/src/websockets/client.py
@@ -7,9 +7,11 @@
 import collections.abc
 import functools
 import logging
+import urllib.request
 import warnings
+from ssl import Purpose, SSLContext, create_default_context
 from types import TracebackType
-from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast
+from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, Union, cast
 
 from .exceptions import (
     InvalidHandshake,
@@ -33,11 +35,13 @@
 from .http import USER_AGENT, Headers, HeadersLike, read_response
 from .protocol import WebSocketCommonProtocol
 from .typing import ExtensionHeader, Origin, Subprotocol
-from .uri import WebSocketURI, parse_uri
+from .uri import ProxyURI, WebSocketURI, parse_proxy_uri, parse_uri
 
 
 __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
 
+USE_SYSTEM_PROXY = object()
+
 logger = logging.getLogger(__name__)
 
 
@@ -223,6 +227,76 @@ def process_subprotocol(
 
         return subprotocol
 
+    async def proxy_connect(
+        self,
+        proxy_uri: ProxyURI,
+        wsuri: WebSocketURI,
+        ssl: Optional[Union[SSLContext, bool]] = None,
+        server_hostname: Optional[str] = None,
+    ) -> None:
+        """
+        Issue a CONNECT request, read a response and upgrade the connection
+        to TLS, if necessary.
+
+        :param proxy_uri: the URI of the HTTP proxy
+        :param wsuri: the original WebSocket URI to connect to
+        :param ssl: an optional :class:`~ssl.SSLContext` with
+            TLS settings for the proxied HTTPS connection; ``None``
+            for not allowing TLS
+        :param server_hostname: the host name to match the server's
+            certificate against; defaults to the host of ``wsuri``
+        :raises ValueError: if the proxy returns an error code
+
+        """
+        request_headers = Headers()
+
+        if wsuri.port == (443 if wsuri.secure else 80):  # pragma: no cover
+            request_headers["Host"] = wsuri.host
+        else:
+            request_headers["Host"] = f"{wsuri.host}:{wsuri.port}"
+
+        if proxy_uri.user_info:
+            request_headers["Proxy-Authorization"] = build_authorization_basic(
+                *proxy_uri.user_info
+            )
+
+        logger.debug("%s > CONNECT %s HTTP/1.1", self.side, f"{wsuri.host}:{wsuri.port}")
+        logger.debug("%s > %r", self.side, request_headers)
+
+        request = f"CONNECT {wsuri.host}:{wsuri.port} HTTP/1.1\r\n"
+        request += str(request_headers)
+
+        self.transport.write(request.encode())
+
+        try:
+            status_code, reason, headers = await read_response(self.reader)
+        except asyncio.CancelledError:  # pragma: no cover
+            raise
+        except Exception as exc:
+            raise InvalidMessage("did not receive a valid HTTP response") from exc
+
+        logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
+        logger.debug("%s < %r", self.side, headers)
+
+        if not 200 <= status_code < 300:
+            # TODO improve error handling
+            raise ValueError(f"proxy error: HTTP {status_code} {reason}")
+
+        if ssl is not None:
+            # The tunnel was opened to the proxy, so the certificate must be
+            # checked against the WebSocket server's host unless overridden.
+            if server_hostname is None:
+                server_hostname = wsuri.host
+            transport = await self.loop.start_tls(
+                self.transport,
+                self,
+                sslcontext=create_default_context(Purpose.SERVER_AUTH) if isinstance(ssl, bool) else ssl,
+                server_side=False,
+                server_hostname=server_hostname,
+            )
+            self.reader = asyncio.StreamReader(limit=self.read_limit // 2, loop=self.loop)
+            self.connection_made(transport)
+
     async def handshake(
         self,
         wsuri: WebSocketURI,
@@ -360,6 +434,12 @@ class Connect:
       :class:`~websockets.http.Headers` instance, a
       :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
       pairs
+    * ``proxy_uri`` defines the HTTP proxy for establishing the connection; by
+      default, :func:`connect` uses proxies configured in the environment or
+      the system (see :func:`~urllib.request.getproxies` for details); set
+      ``proxy_uri`` to ``None`` to disable this behavior
+    * ``proxy_ssl`` may be set to a :class:`~ssl.SSLContext` to enforce TLS
+      settings for connecting to a ``https://`` proxy; it defaults to ``True``
 
     :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
     :raises ~websockets.handshake.InvalidHandshake: if the opening handshake
@@ -391,8 +471,14 @@ def __init__(
         extensions: Optional[Sequence[ClientExtensionFactory]] = None,
         subprotocols: Optional[Sequence[Subprotocol]] = None,
         extra_headers: Optional[HeadersLike] = None,
+        proxy_uri: Union[str, object] = USE_SYSTEM_PROXY,
+        proxy_ssl: Optional[Union[SSLContext, bool]] = None,
         **kwargs: Any,
     ) -> None:
+        conn_host: Optional[str]
+        conn_port: Optional[int]
+        conn_ssl: Optional[Union[SSLContext, bool]]
+
         # Backwards compatibility: close_timeout used to be called timeout.
         if timeout is None:
             timeout = 10
@@ -423,6 +509,38 @@ def __init__(
                 "use a wss:// URI to enable TLS"
             )
 
+        if proxy_uri is USE_SYSTEM_PROXY:
+            if path is not None:
+                # Unix socket connections are local; never proxy them.
+                proxy_uri = None
+            elif urllib.request.proxy_bypass(f"{wsuri.host}:{wsuri.port}"):
+                proxy_uri = None
+            else:
+                proxies = urllib.request.getproxies()
+                # RFC 6455 recommends to prefer the proxy configured for HTTPS
+                # connections over the proxy configured for HTTP connections.
+                proxy_uri = proxies.get("https")
+                if proxy_uri is None and not wsuri.secure:
+                    proxy_uri = proxies.get("http")
+
+        if proxy_uri is not None:
+            proxy_uri = parse_proxy_uri(proxy_uri)
+            if proxy_uri.secure:
+                if proxy_ssl is None:
+                    proxy_ssl = True
+            elif proxy_ssl is not None:
+                raise ValueError(
+                    "connect() received a TLS/SSL context for an HTTP proxy; "
+                    "use an HTTPS proxy to enable TLS"
+                )
+            conn_host, conn_port, conn_ssl = proxy_uri.host, proxy_uri.port, proxy_ssl
+        else:
+            conn_host, conn_port, conn_ssl = wsuri.host, wsuri.port, kwargs.get("ssl")
+
+        self._ssl = kwargs.pop("ssl", None)
+        if proxy_uri is not None:
+            self._server_hostname = kwargs.pop("server_hostname", None)
+
         if compression == "deflate":
             if extensions is None:
                 extensions = []
@@ -457,26 +575,23 @@ def __init__(
         )
 
         if path is None:
-            host: Optional[str]
-            port: Optional[int]
-            if kwargs.get("sock") is None:
-                host, port = wsuri.host, wsuri.port
-            else:
+            if kwargs.get("sock") is not None:
                 # If sock is given, host and port shouldn't be specified.
-                host, port = None, None
+                conn_host, conn_port = None, None
             # If host and port are given, override values from the URI.
-            host = kwargs.pop("host", host)
-            port = kwargs.pop("port", port)
+            conn_host = kwargs.pop("host", conn_host)
+            conn_port = kwargs.pop("port", conn_port)
             create_connection = functools.partial(
-                loop.create_connection, factory, host, port, **kwargs
+                loop.create_connection, factory, conn_host, conn_port, ssl=conn_ssl, **kwargs
             )
         else:
             create_connection = functools.partial(
-                loop.create_unix_connection, factory, path, **kwargs
+                loop.create_unix_connection, factory, path, ssl=conn_ssl, **kwargs
             )
 
         # This is a coroutine function.
         self._create_connection = create_connection
+        self._proxy_uri = proxy_uri
         self._wsuri = wsuri
 
     def handle_redirect(self, uri: str) -> None:
@@ -541,6 +656,13 @@ async def __await_impl__(self) -> WebSocketClientProtocol:
 
             try:
                 try:
+                    if self._proxy_uri is not None:
+                        await protocol.proxy_connect(
+                            self._proxy_uri,
+                            self._wsuri,
+                            self._ssl,
+                            self._server_hostname,
+                        )
                     await protocol.handshake(
                         self._wsuri,
                         origin=protocol.origin,
diff --git a/src/websockets/uri.py b/src/websockets/uri.py
index 6669e5668..110e5f452 100644
--- a/src/websockets/uri.py
+++ b/src/websockets/uri.py
@@ -13,7 +13,10 @@
 from .exceptions import InvalidURI
 
 
-__all__ = ["parse_uri", "WebSocketURI"]
+__all__ = [
+    "parse_uri", "WebSocketURI",
+    "parse_proxy_uri", "ProxyURI",
+]
 
 
 # Consider converting to a dataclass when dropping support for Python < 3.7.
@@ -79,3 +82,65 @@ def parse_uri(uri: str) -> WebSocketURI:
             raise InvalidURI(uri)
         user_info = (parsed.username, parsed.password)
     return WebSocketURI(secure, host, port, resource_name, user_info)
+
+class ProxyURI(NamedTuple):
+    """
+    Proxy URI.
+
+    :param bool secure: tells whether to connect to the proxy with TLS
+    :param str host: lower-case host
+    :param int port: port, always set even if it's the default
+    :param str user_info: ``(username, password)`` tuple when the URI contains
+      `User Information`_, else ``None``.
+
+    .. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1
+    """
+
+    secure: bool
+    host: str
+    port: int
+    user_info: Optional[Tuple[str, str]]
+
+# Work around https://bugs.python.org/issue19931
+
+ProxyURI.secure.__doc__ = ""
+ProxyURI.host.__doc__ = ""
+ProxyURI.port.__doc__ = ""
+ProxyURI.user_info.__doc__ = ""
+
+def parse_proxy_uri(uri: str) -> ProxyURI:
+    """
+    Parse and validate an HTTP proxy URI.
+
+    :raises ~websockets.exceptions.InvalidURI: if ``uri`` isn't a valid HTTP proxy URI.
+
+    """
+    parsed = urllib.parse.urlparse(uri)
+    # Validate explicitly: assert statements are stripped under ``python -O``.
+    if (
+        parsed.scheme not in ("http", "https")
+        or parsed.hostname is None
+        or parsed.path not in ("", "/")
+        or parsed.params != ""
+        or parsed.query != ""
+        or parsed.fragment != ""
+    ):
+        raise InvalidURI(uri)
+
+    secure = parsed.scheme == "https"
+    host = parsed.hostname
+    try:
+        port = parsed.port or (443 if secure else 80)
+    except ValueError as exc:
+        # Reading ``port`` raises ValueError for a non-numeric or out-of-range port.
+        raise InvalidURI(uri) from exc
+    user_info = None
+    if parsed.username is not None:
+        # urllib.parse.urlparse accepts URLs with a username but without a
+        # password. This doesn't make sense for HTTP Basic Auth credentials.
+        if parsed.password is None:
+            raise InvalidURI(uri)
+        user_info = (parsed.username, parsed.password)
+    return ProxyURI(secure, host, port, user_info)
+
+
diff --git a/tests/test_uri.py b/tests/test_uri.py
index e41860b8e..8f8c7c7a0 100644
--- a/tests/test_uri.py
+++ b/tests/test_uri.py
@@ -19,15 +19,40 @@
     "ws://user@localhost/",
 ]
 
+VALID_PROXY_URIS = [
+    ("http://localhost", (False, "localhost", 80, None)),
+    ("http://localhost/", (False, "localhost", 80, None)),
+    ("https://localhost", (True, "localhost", 443, None)),
+    ("http://user:pass@localhost", (False, "localhost", 80, ("user", "pass"))),
+]
+
+INVALID_PROXY_URIS = [
+    "http://localhost/path",
+    "http://localhost:port/",
+    "http://user@localhost/",
+    "ws://localhost/",
+    "wss://localhost/",
+]
 
 class URITests(unittest.TestCase):
-    def test_success(self):
+    def test_parse_uri_success(self):
         for uri, parsed in VALID_URIS:
             with self.subTest(uri=uri):
                 self.assertEqual(parse_uri(uri), parsed)
 
-    def test_error(self):
+    def test_parse_uri_error(self):
         for uri in INVALID_URIS:
             with self.subTest(uri=uri):
                 with self.assertRaises(InvalidURI):
                     parse_uri(uri)
+
+    def test_parse_proxy_uri_success(self):
+        for uri, parsed in VALID_PROXY_URIS:
+            with self.subTest(uri=uri):
+                self.assertEqual(parse_proxy_uri(uri), parsed)
+
+    def test_parse_proxy_uri_error(self):
+        for uri in INVALID_PROXY_URIS:
+            with self.subTest(uri=uri):
+                with self.assertRaises(InvalidURI):
+                    parse_proxy_uri(uri)