Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add tests for blacklisting reactor/agent. #9563

Merged
merged 6 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9563.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
26 changes: 14 additions & 12 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
Expand Down Expand Up @@ -151,16 +152,17 @@ def __init__(
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:

r = recv()
addresses = [] # type: List[IAddress]

def _callback() -> None:
r.resolutionBegan(None)

has_bad_ip = False
for i in addresses:
ip_address = IPAddress(i.host)
for address in addresses:
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
# should go through this path.
if not isinstance(address, (IPv4Address, IPv6Address)):
continue

ip_address = IPAddress(address.host)

if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
Expand All @@ -175,15 +177,15 @@ def _callback() -> None:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
for i in addresses:
r.addressResolved(i)
r.resolutionComplete()
for address in addresses:
recv.addressResolved(address)
recv.resolutionComplete()
Comment on lines +181 to +182
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these recvs need to be recv()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned this in 315061f, but recv does not need to be instantiated. (It actually shouldn't be, it isn't a callable necessarily!) The only reason the current code works is that the implementation of EndpointReceiver happens to use a class with all static methods, so instantiating it essentially has the same methods called as if it isn't instantiated.

tl;dr recv is meant to be an instance, not a class.


@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
recv.resolutionBegan(resolutionInProgress)

@staticmethod
def addressResolved(address: IAddress) -> None:
Expand All @@ -197,7 +199,7 @@ def resolutionComplete() -> None:
EndpointReceiver, hostname, portNumber=portNumber
)

return r
return recv


@implementer(ISynapseReactor)
Expand Down Expand Up @@ -346,7 +348,7 @@ def __init__(
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
)
) # type: IAgent

if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
Expand Down
126 changes: 124 additions & 2 deletions tests/http/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,23 @@

from mock import Mock

from netaddr import IPSet

from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH

from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.api.errors import SynapseError
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
read_body_with_max_size,
)

from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase


Expand Down Expand Up @@ -119,3 +130,114 @@ def test_content_length(self):

# The data is never consumed.
self.assertEqual(result.getvalue(), b"")


class BlacklistingAgentTest(TestCase):
def setUp(self):
self.reactor, self.clock = get_clock()

self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"

# Configure the reactor's DNS resolver.
for (domain, ip) in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()

self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])

def test_reactor(self):
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
),
)

# The unsafe domains and IPs should be rejected.
for domain in (self.unsafe_domain, self.unsafe_ip):
self.failureResultOf(
agent.request(b"GET", b"http://" + domain), DNSLookupError
)

# The safe domains IPs should be accepted.
for domain in (
self.safe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)

# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]

# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)

response = self.successResultOf(d)
self.assertEqual(response.code, 200)

def test_agent(self):
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
)

# The unsafe IPs should be rejected.
self.failureResultOf(
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
)

# The safe and unsafe domains and safe IPs should be accepted.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the unsafe domain accepted? Shouldn't it resolve to the unsafe IP and then be blocked?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The agent doesn't do any resolution, it only attempts to block access when IPs are directly accessed. The reactor ensures that resolved IPs can be blocked effectively. This is tested above. You should use both of them in tandem to ensure all cases are blocked.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, and the docstring of BlacklistingAgentWrapper does indeed say as much. Ah well, thank you!

for domain in (
self.safe_domain,
self.unsafe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)

# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]

# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)

response = self.successResultOf(d)
self.assertEqual(response.code, 200)