Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix tests on Twisted trunk. #16528

Merged
merged 11 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16528.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix running unit tests on Twisted trunk.
37 changes: 35 additions & 2 deletions tests/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
import subprocess
from typing import List

from incremental import Version
from zope.interface import implementer

import twisted
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.interfaces import (
IOpenSSLServerConnectionCreator,
IProtocolFactory,
IReactorTime,
)
from twisted.internet.ssl import Certificate, trustRootFromCertificates
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401

Expand Down Expand Up @@ -153,6 +159,33 @@ def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connectio
return Connection(ctx, None)


def wrap_server_factory_for_tls(
factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes]
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory

The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`

Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for

Returns:
interfaces.IProtocolFactory
"""
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
# Twisted > 23.8.0 has a different API that accepts a clock.
if twisted.version <= Version("Twisted", 23, 8, 0):
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
else:
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory, clock=clock # type: ignore[call-arg]
)
Comment on lines +178 to +186
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did Twisted make a breaking change to TLSMemoryBIOFactory here, or are we just taking advantage of a newer API?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

@clokep clokep Oct 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't a breaking change, you don't have to provide clock, but you do if you want our tests to pass. (Otherwise it uses the trial reactor instead of the test reactor.)



# A dummy address, useful for tests that use FakeTransport and don't care about where
# packets are going to/coming from.
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
60 changes: 24 additions & 36 deletions tests/http/federation/test_matrix_federation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
IProtocolFactory,
)
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
from twisted.web.http import HTTPChannel, Request
Expand All @@ -57,11 +57,7 @@
from synapse.util.caches.ttlcache import TTLCache

from tests import unittest
from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_ca_cert_file,
)
from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.utils import checked_cast, default_config

Expand Down Expand Up @@ -125,7 +121,18 @@ def _make_connection(
# build the test server
server_factory = _get_test_protocol_factory()
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_factory = wrap_server_factory_for_tls(
server_factory,
self.reactor,
tls_sanlist
or [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
],
)

server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
Expand Down Expand Up @@ -435,8 +442,16 @@ def _do_get_via_proxy(
request.finish()

# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
server_ssl_protocol = wrap_server_factory_for_tls(
_get_test_protocol_factory(),
self.reactor,
sanlist=[
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
],
).buildProtocol(dummy_address)

# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
Expand Down Expand Up @@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))


def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
]

connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)


def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
Returns:
Expand Down
44 changes: 10 additions & 34 deletions tests/http/test_proxyagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,14 @@
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel

from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import BasicProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy

from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_https_policy,
)
from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
from tests.utils import checked_cast
Expand Down Expand Up @@ -251,7 +247,9 @@ def _make_connection(
the server Protocol returned by server_factory
"""
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_factory = wrap_server_factory_for_tls(
server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"]
)

server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
Expand Down Expand Up @@ -618,8 +616,8 @@ def _do_https_request_via_proxy(
request.finish()

# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
server_ssl_protocol = wrap_server_factory_for_tls(
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
Comment on lines -621 to +620
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is passing in an explicit sanlist here a meaningful change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto below in the rest of this test

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No: _wrap_server_factory_for_tls defaulted to this if no sanlist argument was provided.

).buildProtocol(dummy_address)

# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
Expand Down Expand Up @@ -785,7 +783,9 @@ def test_https_request_via_uppercase_proxy_with_blocklist(self) -> None:
request.finish()

# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
ssl_factory = wrap_server_factory_for_tls(
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
)
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
http_server = ssl_protocol.wrappedProtocol
Expand Down Expand Up @@ -849,30 +849,6 @@ def test_proxy_with_https_scheme(self) -> None:
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)


def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory

The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`

Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for

Returns:
interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [b"DNS:test.com"]

connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)


def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel

Expand Down
52 changes: 13 additions & 39 deletions tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import os
from typing import Any, Optional, Tuple

from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
Expand All @@ -27,7 +25,11 @@
from synapse.server import HomeServer
from synapse.util import Clock

from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.http import (
TestServerTLSConnectionFactory,
get_test_ca_cert_file,
wrap_server_factory_for_tls,
)
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
Expand Down Expand Up @@ -94,7 +96,13 @@ def _get_media_req(
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()

# build the test server
server_tls_protocol = _build_test_server(get_connection_factory())
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request

server_tls_protocol = wrap_server_factory_for_tls(
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
).buildProtocol(None)

# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
Expand All @@ -114,7 +122,7 @@ def _get_media_req(
)

# fish the test server back out of the server-side TLS protocol.
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol

# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
Expand Down Expand Up @@ -240,40 +248,6 @@ def _count_remote_thumbnails(self) -> int:
return sum(len(files) for _, _, files in os.walk(path))


def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
test_server_connection_factory = TestServerTLSConnectionFactory(
sanlist=[b"DNS:example.com"]
)
return test_server_connection_factory


def _build_test_server(
connection_creator: IOpenSSLServerConnectionCreator,
) -> TLSMemoryBIOProtocol:
"""Construct a test server

This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol

Args:
connection_creator: thing to build SSL connections

Returns:
TLSMemoryBIOProtocol
"""
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request

server_tls_factory = TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=server_factory
)

return server_tls_factory.buildProtocol(None)


def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
12 changes: 12 additions & 0 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
from unittest.mock import Mock

import attr
from incremental import Version
from typing_extensions import ParamSpec
from zope.interface import implementer

import twisted
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
Expand Down Expand Up @@ -474,6 +476,16 @@ def getHostByName(
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])

# In order for the TLS protocol tests to work, modify _get_default_clock
# on newer Twisted versions to use the test reactor's clock.
#
# This is *super* dirty since it is never undone and relies on the next
# test to overwrite it.
if twisted.version > Version("Twisted", 23, 8, 0):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels janky that we check > here and <= to above... as if there is at 23.8.x then it'll enable one bit of code and not the other.

It sounds like a Twisted 23.1x.0 will be out soon-ish though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can use >= 23.10.0 if we want (in both places), but it'll fail until the release is out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 23.8 is fine, just a note for completeness.

from twisted.protocols import tls

tls._get_default_clock = lambda: self # type: ignore[attr-defined]

self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()

Expand Down
Loading