Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit e55bd0e

Browse files
authored
Add tests for blacklisting reactor/agent. (#9563)
1 parent 70d1b6a commit e55bd0e

File tree

3 files changed

+139
-14
lines changed

3 files changed

+139
-14
lines changed

changelog.d/9563.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.

synapse/http/client.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from OpenSSL import SSL
4040
from OpenSSL.SSL import VERIFY_NONE
4141
from twisted.internet import defer, error as twisted_error, protocol, ssl
42+
from twisted.internet.address import IPv4Address, IPv6Address
4243
from twisted.internet.interfaces import (
4344
IAddress,
4445
IHostResolution,
@@ -151,16 +152,17 @@ def __init__(
151152
def resolveHostName(
152153
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
153154
) -> IResolutionReceiver:
154-
155-
r = recv()
156155
addresses = [] # type: List[IAddress]
157156

158157
def _callback() -> None:
159-
r.resolutionBegan(None)
160-
161158
has_bad_ip = False
162-
for i in addresses:
163-
ip_address = IPAddress(i.host)
159+
for address in addresses:
160+
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
161+
# should go through this path.
162+
if not isinstance(address, (IPv4Address, IPv6Address)):
163+
continue
164+
165+
ip_address = IPAddress(address.host)
164166

165167
if check_against_blacklist(
166168
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -175,15 +177,15 @@ def _callback() -> None:
175177
# request, but all we can really do from here is claim that there were no
176178
# valid results.
177179
if not has_bad_ip:
178-
for i in addresses:
179-
r.addressResolved(i)
180-
r.resolutionComplete()
180+
for address in addresses:
181+
recv.addressResolved(address)
182+
recv.resolutionComplete()
181183

182184
@provider(IResolutionReceiver)
183185
class EndpointReceiver:
184186
@staticmethod
185187
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
186-
pass
188+
recv.resolutionBegan(resolutionInProgress)
187189

188190
@staticmethod
189191
def addressResolved(address: IAddress) -> None:
@@ -197,7 +199,7 @@ def resolutionComplete() -> None:
197199
EndpointReceiver, hostname, portNumber=portNumber
198200
)
199201

200-
return r
202+
return recv
201203

202204

203205
@implementer(ISynapseReactor)
@@ -346,7 +348,7 @@ def __init__(
346348
contextFactory=self.hs.get_http_client_context_factory(),
347349
pool=pool,
348350
use_proxy=use_proxy,
349-
)
351+
) # type: IAgent
350352

351353
if self._ip_blacklist:
352354
# If we have an IP blacklist, we then install the blacklisting Agent

tests/http/test_client.py

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,23 @@
1616

1717
from mock import Mock
1818

19+
from netaddr import IPSet
20+
21+
from twisted.internet.error import DNSLookupError
1922
from twisted.python.failure import Failure
20-
from twisted.web.client import ResponseDone
23+
from twisted.test.proto_helpers import AccumulatingProtocol
24+
from twisted.web.client import Agent, ResponseDone
2125
from twisted.web.iweb import UNKNOWN_LENGTH
2226

23-
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
27+
from synapse.api.errors import SynapseError
28+
from synapse.http.client import (
29+
BlacklistingAgentWrapper,
30+
BlacklistingReactorWrapper,
31+
BodyExceededMaxSize,
32+
read_body_with_max_size,
33+
)
2434

35+
from tests.server import FakeTransport, get_clock
2536
from tests.unittest import TestCase
2637

2738

@@ -119,3 +130,114 @@ def test_content_length(self):
119130

120131
# The data is never consumed.
121132
self.assertEqual(result.getvalue(), b"")
133+
134+
135+
class BlacklistingAgentTest(TestCase):
136+
def setUp(self):
137+
self.reactor, self.clock = get_clock()
138+
139+
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
140+
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
141+
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
142+
143+
# Configure the reactor's DNS resolver.
144+
for (domain, ip) in (
145+
(self.safe_domain, self.safe_ip),
146+
(self.unsafe_domain, self.unsafe_ip),
147+
(self.allowed_domain, self.allowed_ip),
148+
):
149+
self.reactor.lookups[domain.decode()] = ip.decode()
150+
self.reactor.lookups[ip.decode()] = ip.decode()
151+
152+
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
153+
self.ip_blacklist = IPSet(["5.0.0.0/8"])
154+
155+
def test_reactor(self):
156+
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
157+
agent = Agent(
158+
BlacklistingReactorWrapper(
159+
self.reactor,
160+
ip_whitelist=self.ip_whitelist,
161+
ip_blacklist=self.ip_blacklist,
162+
),
163+
)
164+
165+
# The unsafe domains and IPs should be rejected.
166+
for domain in (self.unsafe_domain, self.unsafe_ip):
167+
self.failureResultOf(
168+
agent.request(b"GET", b"http://" + domain), DNSLookupError
169+
)
170+
171+
# The safe domains IPs should be accepted.
172+
for domain in (
173+
self.safe_domain,
174+
self.allowed_domain,
175+
self.safe_ip,
176+
self.allowed_ip,
177+
):
178+
d = agent.request(b"GET", b"http://" + domain)
179+
180+
# Grab the latest TCP connection.
181+
(
182+
host,
183+
port,
184+
client_factory,
185+
_timeout,
186+
_bindAddress,
187+
) = self.reactor.tcpClients[-1]
188+
189+
# Make the connection and pump data through it.
190+
client = client_factory.buildProtocol(None)
191+
server = AccumulatingProtocol()
192+
server.makeConnection(FakeTransport(client, self.reactor))
193+
client.makeConnection(FakeTransport(server, self.reactor))
194+
client.dataReceived(
195+
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
196+
)
197+
198+
response = self.successResultOf(d)
199+
self.assertEqual(response.code, 200)
200+
201+
def test_agent(self):
202+
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
203+
agent = BlacklistingAgentWrapper(
204+
Agent(self.reactor),
205+
ip_whitelist=self.ip_whitelist,
206+
ip_blacklist=self.ip_blacklist,
207+
)
208+
209+
# The unsafe IPs should be rejected.
210+
self.failureResultOf(
211+
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
212+
)
213+
214+
# The safe and unsafe domains and safe IPs should be accepted.
215+
for domain in (
216+
self.safe_domain,
217+
self.unsafe_domain,
218+
self.allowed_domain,
219+
self.safe_ip,
220+
self.allowed_ip,
221+
):
222+
d = agent.request(b"GET", b"http://" + domain)
223+
224+
# Grab the latest TCP connection.
225+
(
226+
host,
227+
port,
228+
client_factory,
229+
_timeout,
230+
_bindAddress,
231+
) = self.reactor.tcpClients[-1]
232+
233+
# Make the connection and pump data through it.
234+
client = client_factory.buildProtocol(None)
235+
server = AccumulatingProtocol()
236+
server.makeConnection(FakeTransport(client, self.reactor))
237+
client.makeConnection(FakeTransport(server, self.reactor))
238+
client.dataReceived(
239+
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
240+
)
241+
242+
response = self.successResultOf(d)
243+
self.assertEqual(response.code, 200)

0 commit comments

Comments
 (0)