Skip to content

Commit

Permalink
Rename ssl_context to ssl in sync implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Jan 21, 2024
1 parent 908c7ba commit e21811e
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 46 deletions.
14 changes: 13 additions & 1 deletion docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,23 @@ fixing regressions shortly after a release.
Only documented APIs are public. Undocumented, private APIs may change without
notice.

12.1
13.0
----

*In development*

Backwards-incompatible changes
..............................

.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect`
and :func:`~sync.server.serve` is renamed to ``ssl``.
:class: note

This aligns the API of the :mod:`threading` implementation with the
:mod:`asyncio` implementation.

For backwards compatibility, ``ssl_context`` is still supported.

New features
............

Expand Down
27 changes: 17 additions & 10 deletions src/websockets/sync/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import socket
import ssl
import ssl as ssl_module
import threading
import warnings
from typing import Any, Optional, Sequence, Type

from ..client import ClientProtocol
Expand Down Expand Up @@ -128,7 +129,7 @@ def connect(
*,
# TCP/TLS
sock: Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
ssl: Optional[ssl_module.SSLContext] = None,
server_hostname: Optional[str] = None,
# WebSocket
origin: Optional[Origin] = None,
Expand Down Expand Up @@ -166,7 +167,7 @@ def connect(
sock: Preexisting TCP socket. ``sock`` overrides the host and port
from ``uri``. You may call :func:`socket.create_connection` to
create a suitable TCP socket.
ssl_context: Configuration for enabling TLS on the connection.
ssl: Configuration for enabling TLS on the connection.
server_hostname: Host name for the TLS handshake. ``server_hostname``
overrides the host name from ``uri``.
origin: Value of the ``Origin`` header, for servers that require it.
Expand Down Expand Up @@ -207,9 +208,14 @@ def connect(

# Process parameters

# Backwards compatibility: ssl used to be called ssl_context.
if ssl is None and "ssl_context" in kwargs:
ssl = kwargs.pop("ssl_context")
warnings.warn("ssl_context was renamed to ssl", DeprecationWarning)

wsuri = parse_uri(uri)
if not wsuri.secure and ssl_context is not None:
raise TypeError("ssl_context argument is incompatible with a ws:// URI")
if not wsuri.secure and ssl is not None:
raise TypeError("ssl argument is incompatible with a ws:// URI")

# Private APIs for unix_connect()
unix: bool = kwargs.pop("unix", False)
Expand Down Expand Up @@ -259,12 +265,12 @@ def connect(
# Initialize TLS wrapper and perform TLS handshake

if wsuri.secure:
if ssl_context is None:
ssl_context = ssl.create_default_context()
if ssl is None:
ssl = ssl_module.create_default_context()
if server_hostname is None:
server_hostname = wsuri.host
sock.settimeout(deadline.timeout())
sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname)
sock = ssl.wrap_socket(sock, server_hostname=server_hostname)
sock.settimeout(None)

# Initialize WebSocket connection
Expand Down Expand Up @@ -318,12 +324,13 @@ def unix_connect(
Args:
path: File system path to the Unix socket.
uri: URI of the WebSocket server. ``uri`` defaults to
``ws://localhost/`` or, when a ``ssl_context`` is provided, to
``ws://localhost/`` or, when a ``ssl`` is provided, to
``wss://localhost/``.
"""
if uri is None:
if kwargs.get("ssl_context") is None:
# Backwards compatibility: ssl used to be called ssl_context.
if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None:
uri = "ws://localhost/"
else:
uri = "wss://localhost/"
Expand Down
21 changes: 14 additions & 7 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import os
import selectors
import socket
import ssl
import ssl as ssl_module
import sys
import threading
import warnings
from types import TracebackType
from typing import Any, Callable, Optional, Sequence, Type

Expand Down Expand Up @@ -268,7 +269,7 @@ def serve(
*,
# TCP/TLS
sock: Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
ssl: Optional[ssl_module.SSLContext] = None,
# WebSocket
origins: Optional[Sequence[Optional[Origin]]] = None,
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
Expand Down Expand Up @@ -337,7 +338,7 @@ def handler(websocket):
sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``.
You may call :func:`socket.create_server` to create a suitable TCP
socket.
ssl_context: Configuration for enabling TLS on the connection.
ssl: Configuration for enabling TLS on the connection.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
in the list if the lack of an origin is acceptable.
Expand Down Expand Up @@ -386,6 +387,11 @@ def handler(websocket):

# Process parameters

# Backwards compatibility: ssl used to be called ssl_context.
if ssl is None and "ssl_context" in kwargs:
ssl = kwargs.pop("ssl_context")
warnings.warn("ssl_context was renamed to ssl", DeprecationWarning)

if subprotocols is not None:
validate_subprotocols(subprotocols)

Expand Down Expand Up @@ -417,8 +423,8 @@ def handler(websocket):

# Initialize TLS wrapper

if ssl_context is not None:
sock = ssl_context.wrap_socket(
if ssl is not None:
sock = ssl.wrap_socket(
sock,
server_side=True,
# Delay TLS handshake until after we set a timeout on the socket.
Expand All @@ -441,9 +447,10 @@ def conn_handler(sock: socket.socket, addr: Any) -> None:

# Perform TLS handshake

if ssl_context is not None:
if ssl is not None:
sock.settimeout(deadline.timeout())
assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out
# mypy cannot figure this out
assert isinstance(sock, ssl_module.SSLSocket)
sock.do_handshake()
sock.settimeout(None)

Expand Down
3 changes: 2 additions & 1 deletion tests/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs):
else:
assert isinstance(wsuri_or_server, WebSocketServer)
if secure is None:
secure = "ssl_context" in kwargs
# Backwards compatibility: ssl used to be called ssl_context.
secure = "ssl" in kwargs or "ssl_context" in kwargs
protocol = "wss" if secure else "ws"
host, port = wsuri_or_server.socket.getsockname()
wsuri = f"{protocol}://{host}:{port}{resource_name}"
Expand Down
47 changes: 27 additions & 20 deletions tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from websockets.extensions.permessage_deflate import PerMessageDeflate
from websockets.sync.client import *

from ..utils import MS, temp_unix_socket_path
from ..utils import MS, DeprecationTestCase, temp_unix_socket_path
from .client import CLIENT_CONTEXT, run_client, run_unix_client
from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server

Expand Down Expand Up @@ -137,36 +137,36 @@ def close_connection(self, request):
class SecureClientTests(unittest.TestCase):
def test_connection(self):
"""Client connects to server securely."""
with run_server(ssl_context=SERVER_CONTEXT) as server:
with run_client(server, ssl_context=CLIENT_CONTEXT) as client:
with run_server(ssl=SERVER_CONTEXT) as server:
with run_client(server, ssl=CLIENT_CONTEXT) as client:
self.assertEqual(client.protocol.state.name, "OPEN")
self.assertEqual(client.socket.version()[:3], "TLS")

def test_set_server_hostname_implicitly(self):
"""Client sets server_hostname to the host in the WebSocket URI."""
with temp_unix_socket_path() as path:
with run_unix_server(path, ssl_context=SERVER_CONTEXT):
with run_unix_server(path, ssl=SERVER_CONTEXT):
with run_unix_client(
path,
ssl_context=CLIENT_CONTEXT,
ssl=CLIENT_CONTEXT,
uri="wss://overridden/",
) as client:
self.assertEqual(client.socket.server_hostname, "overridden")

def test_set_server_hostname_explicitly(self):
"""Client sets server_hostname to the value provided in argument."""
with temp_unix_socket_path() as path:
with run_unix_server(path, ssl_context=SERVER_CONTEXT):
with run_unix_server(path, ssl=SERVER_CONTEXT):
with run_unix_client(
path,
ssl_context=CLIENT_CONTEXT,
ssl=CLIENT_CONTEXT,
server_hostname="overridden",
) as client:
self.assertEqual(client.socket.server_hostname, "overridden")

def test_reject_invalid_server_certificate(self):
"""Client rejects certificate where server certificate isn't trusted."""
with run_server(ssl_context=SERVER_CONTEXT) as server:
with run_server(ssl=SERVER_CONTEXT) as server:
with self.assertRaisesRegex(
ssl.SSLCertVerificationError,
r"certificate verify failed: self[ -]signed certificate",
Expand All @@ -177,15 +177,13 @@ def test_reject_invalid_server_certificate(self):

def test_reject_invalid_server_hostname(self):
"""Client rejects certificate where server hostname doesn't match."""
with run_server(ssl_context=SERVER_CONTEXT) as server:
with run_server(ssl=SERVER_CONTEXT) as server:
with self.assertRaisesRegex(
ssl.SSLCertVerificationError,
r"certificate verify failed: Hostname mismatch",
):
# This hostname isn't included in the test certificate.
with run_client(
server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid"
):
with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"):
self.fail("did not raise")


Expand All @@ -212,32 +210,32 @@ class SecureUnixClientTests(unittest.TestCase):
def test_connection(self):
"""Client connects to server securely over a Unix socket."""
with temp_unix_socket_path() as path:
with run_unix_server(path, ssl_context=SERVER_CONTEXT):
with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client:
with run_unix_server(path, ssl=SERVER_CONTEXT):
with run_unix_client(path, ssl=CLIENT_CONTEXT) as client:
self.assertEqual(client.protocol.state.name, "OPEN")
self.assertEqual(client.socket.version()[:3], "TLS")

def test_set_server_hostname(self):
"""Client sets server_hostname to the host in the WebSocket URI."""
# This is part of the documented behavior of unix_connect().
with temp_unix_socket_path() as path:
with run_unix_server(path, ssl_context=SERVER_CONTEXT):
with run_unix_server(path, ssl=SERVER_CONTEXT):
with run_unix_client(
path,
ssl_context=CLIENT_CONTEXT,
ssl=CLIENT_CONTEXT,
uri="wss://overridden/",
) as client:
self.assertEqual(client.socket.server_hostname, "overridden")


class ClientUsageErrorsTests(unittest.TestCase):
def test_ssl_context_without_secure_uri(self):
"""Client rejects ssl_context when URI isn't secure."""
def test_ssl_without_secure_uri(self):
"""Client rejects ssl when URI isn't secure."""
with self.assertRaisesRegex(
TypeError,
"ssl_context argument is incompatible with a ws:// URI",
"ssl argument is incompatible with a ws:// URI",
):
connect("ws://localhost/", ssl_context=CLIENT_CONTEXT)
connect("ws://localhost/", ssl=CLIENT_CONTEXT)

def test_unix_without_path_or_sock(self):
"""Unix client requires path when sock isn't provided."""
Expand Down Expand Up @@ -272,3 +270,12 @@ def test_unsupported_compression(self):
"unsupported compression: False",
):
connect("ws://localhost/", compression=False)


class BackwardsCompatibilityTests(DeprecationTestCase):
def test_ssl_context_argument(self):
"""Client supports the deprecated ssl_context argument."""
with run_server(ssl=SERVER_CONTEXT) as server:
with self.assertDeprecationWarning("ssl_context was renamed to ssl"):
with run_client(server, ssl_context=CLIENT_CONTEXT):
pass
23 changes: 16 additions & 7 deletions tests/sync/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from websockets.http11 import Request, Response
from websockets.sync.server import *

from ..utils import MS, temp_unix_socket_path
from ..utils import MS, DeprecationTestCase, temp_unix_socket_path
from .client import CLIENT_CONTEXT, run_client, run_unix_client
from .server import (
SERVER_CONTEXT,
Expand Down Expand Up @@ -274,20 +274,20 @@ def handler(sock, addr):
class SecureServerTests(EvalShellMixin, unittest.TestCase):
def test_connection(self):
"""Server receives secure connection from client."""
with run_server(ssl_context=SERVER_CONTEXT) as server:
with run_client(server, ssl_context=CLIENT_CONTEXT) as client:
with run_server(ssl=SERVER_CONTEXT) as server:
with run_client(server, ssl=CLIENT_CONTEXT) as client:
self.assertEval(client, "ws.protocol.state.name", "OPEN")
self.assertEval(client, "ws.socket.version()[:3]", "TLS")

def test_timeout_during_tls_handshake(self):
"""Server times out before receiving TLS handshake request from client."""
with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server:
with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server:
with socket.create_connection(server.socket.getsockname()) as sock:
self.assertEqual(sock.recv(4096), b"")

def test_connection_closed_during_tls_handshake(self):
"""Server reads EOF before receiving TLS handshake request from client."""
with run_server(ssl_context=SERVER_CONTEXT) as server:
with run_server(ssl=SERVER_CONTEXT) as server:
# Patch handler to record a reference to the thread running it.
server_thread = None
conn_received = threading.Event()
Expand Down Expand Up @@ -325,8 +325,8 @@ class SecureUnixServerTests(EvalShellMixin, unittest.TestCase):
def test_connection(self):
"""Server receives secure connection from client over a Unix socket."""
with temp_unix_socket_path() as path:
with run_unix_server(path, ssl_context=SERVER_CONTEXT):
with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client:
with run_unix_server(path, ssl=SERVER_CONTEXT):
with run_unix_client(path, ssl=CLIENT_CONTEXT) as client:
self.assertEval(client, "ws.protocol.state.name", "OPEN")
self.assertEval(client, "ws.socket.version()[:3]", "TLS")

Expand Down Expand Up @@ -386,3 +386,12 @@ def test_shutdown(self):
# Check that the server socket is closed.
with self.assertRaises(OSError):
server.socket.accept()


class BackwardsCompatibilityTests(DeprecationTestCase):
def test_ssl_context_argument(self):
"""Client supports the deprecated ssl_context argument."""
with self.assertDeprecationWarning("ssl_context was renamed to ssl"):
with run_server(ssl_context=SERVER_CONTEXT) as server:
with run_client(server, ssl=CLIENT_CONTEXT):
pass

0 comments on commit e21811e

Please sign in to comment.