Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add tests for blacklisting reactor/agent. (#9563)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Mar 11, 2021
1 parent 70d1b6a commit e55bd0e
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 14 deletions.
1 change: 1 addition & 0 deletions changelog.d/9563.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
26 changes: 14 additions & 12 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
Expand Down Expand Up @@ -151,16 +152,17 @@ def __init__(
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:

r = recv()
addresses = [] # type: List[IAddress]

def _callback() -> None:
r.resolutionBegan(None)

has_bad_ip = False
for i in addresses:
ip_address = IPAddress(i.host)
for address in addresses:
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
# should go through this path.
if not isinstance(address, (IPv4Address, IPv6Address)):
continue

ip_address = IPAddress(address.host)

if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
Expand All @@ -175,15 +177,15 @@ def _callback() -> None:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
for i in addresses:
r.addressResolved(i)
r.resolutionComplete()
for address in addresses:
recv.addressResolved(address)
recv.resolutionComplete()

@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
recv.resolutionBegan(resolutionInProgress)

@staticmethod
def addressResolved(address: IAddress) -> None:
Expand All @@ -197,7 +199,7 @@ def resolutionComplete() -> None:
EndpointReceiver, hostname, portNumber=portNumber
)

return r
return recv


@implementer(ISynapseReactor)
Expand Down Expand Up @@ -346,7 +348,7 @@ def __init__(
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
)
) # type: IAgent

if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
Expand Down
126 changes: 124 additions & 2 deletions tests/http/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,23 @@

from mock import Mock

from netaddr import IPSet

from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH

from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.api.errors import SynapseError
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
read_body_with_max_size,
)

from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase


Expand Down Expand Up @@ -119,3 +130,114 @@ def test_content_length(self):

# The data is never consumed.
self.assertEqual(result.getvalue(), b"")


class BlacklistingAgentTest(TestCase):
def setUp(self):
self.reactor, self.clock = get_clock()

self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"

# Configure the reactor's DNS resolver.
for (domain, ip) in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()

self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])

def test_reactor(self):
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
),
)

# The unsafe domains and IPs should be rejected.
for domain in (self.unsafe_domain, self.unsafe_ip):
self.failureResultOf(
agent.request(b"GET", b"http://" + domain), DNSLookupError
)

# The safe domains IPs should be accepted.
for domain in (
self.safe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)

# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]

# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)

response = self.successResultOf(d)
self.assertEqual(response.code, 200)

def test_agent(self):
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
)

# The unsafe IPs should be rejected.
self.failureResultOf(
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
)

# The safe and unsafe domains and safe IPs should be accepted.
for domain in (
self.safe_domain,
self.unsafe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)

# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]

# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)

response = self.successResultOf(d)
self.assertEqual(response.code, 200)

0 comments on commit e55bd0e

Please sign in to comment.