|
16 | 16 |
|
17 | 17 | from mock import Mock
|
18 | 18 |
|
| 19 | +from netaddr import IPSet |
| 20 | + |
| 21 | +from twisted.internet.error import DNSLookupError |
19 | 22 | from twisted.python.failure import Failure
|
20 |
| -from twisted.web.client import ResponseDone |
| 23 | +from twisted.test.proto_helpers import AccumulatingProtocol |
| 24 | +from twisted.web.client import Agent, ResponseDone |
21 | 25 | from twisted.web.iweb import UNKNOWN_LENGTH
|
22 | 26 |
|
23 |
| -from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size |
| 27 | +from synapse.api.errors import SynapseError |
| 28 | +from synapse.http.client import ( |
| 29 | + BlacklistingAgentWrapper, |
| 30 | + BlacklistingReactorWrapper, |
| 31 | + BodyExceededMaxSize, |
| 32 | + read_body_with_max_size, |
| 33 | +) |
24 | 34 |
|
| 35 | +from tests.server import FakeTransport, get_clock |
25 | 36 | from tests.unittest import TestCase
|
26 | 37 |
|
27 | 38 |
|
@@ -119,3 +130,114 @@ def test_content_length(self):
|
119 | 130 |
|
120 | 131 | # The data is never consumed.
|
121 | 132 | self.assertEqual(result.getvalue(), b"")
|
| 133 | + |
| 134 | + |
| 135 | +class BlacklistingAgentTest(TestCase): |
| 136 | + def setUp(self): |
| 137 | + self.reactor, self.clock = get_clock() |
| 138 | + |
| 139 | + self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" |
| 140 | + self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8" |
| 141 | + self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" |
| 142 | + |
| 143 | + # Configure the reactor's DNS resolver. |
| 144 | + for (domain, ip) in ( |
| 145 | + (self.safe_domain, self.safe_ip), |
| 146 | + (self.unsafe_domain, self.unsafe_ip), |
| 147 | + (self.allowed_domain, self.allowed_ip), |
| 148 | + ): |
| 149 | + self.reactor.lookups[domain.decode()] = ip.decode() |
| 150 | + self.reactor.lookups[ip.decode()] = ip.decode() |
| 151 | + |
| 152 | + self.ip_whitelist = IPSet([self.allowed_ip.decode()]) |
| 153 | + self.ip_blacklist = IPSet(["5.0.0.0/8"]) |
| 154 | + |
| 155 | + def test_reactor(self): |
| 156 | + """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" |
| 157 | + agent = Agent( |
| 158 | + BlacklistingReactorWrapper( |
| 159 | + self.reactor, |
| 160 | + ip_whitelist=self.ip_whitelist, |
| 161 | + ip_blacklist=self.ip_blacklist, |
| 162 | + ), |
| 163 | + ) |
| 164 | + |
| 165 | + # The unsafe domains and IPs should be rejected. |
| 166 | + for domain in (self.unsafe_domain, self.unsafe_ip): |
| 167 | + self.failureResultOf( |
| 168 | + agent.request(b"GET", b"http://" + domain), DNSLookupError |
| 169 | + ) |
| 170 | + |
| 171 | + # The safe domains IPs should be accepted. |
| 172 | + for domain in ( |
| 173 | + self.safe_domain, |
| 174 | + self.allowed_domain, |
| 175 | + self.safe_ip, |
| 176 | + self.allowed_ip, |
| 177 | + ): |
| 178 | + d = agent.request(b"GET", b"http://" + domain) |
| 179 | + |
| 180 | + # Grab the latest TCP connection. |
| 181 | + ( |
| 182 | + host, |
| 183 | + port, |
| 184 | + client_factory, |
| 185 | + _timeout, |
| 186 | + _bindAddress, |
| 187 | + ) = self.reactor.tcpClients[-1] |
| 188 | + |
| 189 | + # Make the connection and pump data through it. |
| 190 | + client = client_factory.buildProtocol(None) |
| 191 | + server = AccumulatingProtocol() |
| 192 | + server.makeConnection(FakeTransport(client, self.reactor)) |
| 193 | + client.makeConnection(FakeTransport(server, self.reactor)) |
| 194 | + client.dataReceived( |
| 195 | + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" |
| 196 | + ) |
| 197 | + |
| 198 | + response = self.successResultOf(d) |
| 199 | + self.assertEqual(response.code, 200) |
| 200 | + |
| 201 | + def test_agent(self): |
| 202 | + """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" |
| 203 | + agent = BlacklistingAgentWrapper( |
| 204 | + Agent(self.reactor), |
| 205 | + ip_whitelist=self.ip_whitelist, |
| 206 | + ip_blacklist=self.ip_blacklist, |
| 207 | + ) |
| 208 | + |
| 209 | + # The unsafe IPs should be rejected. |
| 210 | + self.failureResultOf( |
| 211 | + agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError |
| 212 | + ) |
| 213 | + |
| 214 | + # The safe and unsafe domains and safe IPs should be accepted. |
| 215 | + for domain in ( |
| 216 | + self.safe_domain, |
| 217 | + self.unsafe_domain, |
| 218 | + self.allowed_domain, |
| 219 | + self.safe_ip, |
| 220 | + self.allowed_ip, |
| 221 | + ): |
| 222 | + d = agent.request(b"GET", b"http://" + domain) |
| 223 | + |
| 224 | + # Grab the latest TCP connection. |
| 225 | + ( |
| 226 | + host, |
| 227 | + port, |
| 228 | + client_factory, |
| 229 | + _timeout, |
| 230 | + _bindAddress, |
| 231 | + ) = self.reactor.tcpClients[-1] |
| 232 | + |
| 233 | + # Make the connection and pump data through it. |
| 234 | + client = client_factory.buildProtocol(None) |
| 235 | + server = AccumulatingProtocol() |
| 236 | + server.makeConnection(FakeTransport(client, self.reactor)) |
| 237 | + client.makeConnection(FakeTransport(server, self.reactor)) |
| 238 | + client.dataReceived( |
| 239 | + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" |
| 240 | + ) |
| 241 | + |
| 242 | + response = self.successResultOf(d) |
| 243 | + self.assertEqual(response.code, 200) |
0 commit comments