22
22
from twisted .internet .task import LoopingCall
23
23
from twisted .web .http import HTTPChannel
24
24
from twisted .web .resource import Resource
25
+ from twisted .web .server import Request , Site
25
26
26
27
from synapse .app .generic_worker import (
27
28
GenericWorkerReplicationHandler ,
32
33
from synapse .replication .http import ReplicationRestResource
33
34
from synapse .replication .tcp .handler import ReplicationCommandHandler
34
35
from synapse .replication .tcp .protocol import ClientReplicationStreamProtocol
35
- from synapse .replication .tcp .resource import ReplicationStreamProtocolFactory
36
+ from synapse .replication .tcp .resource import (
37
+ ReplicationStreamProtocolFactory ,
38
+ ServerReplicationStreamProtocol ,
39
+ )
36
40
from synapse .server import HomeServer
37
41
from synapse .util import Clock
38
42
@@ -59,7 +63,9 @@ def prepare(self, reactor, clock, hs):
59
63
# build a replication server
60
64
server_factory = ReplicationStreamProtocolFactory (hs )
61
65
self .streamer = hs .get_replication_streamer ()
62
- self .server = server_factory .buildProtocol (None )
66
+ self .server = server_factory .buildProtocol (
67
+ None
68
+ ) # type: ServerReplicationStreamProtocol
63
69
64
70
# Make a new HomeServer object for the worker
65
71
self .reactor .lookups ["testserv" ] = "1.2.3.4"
@@ -155,9 +161,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
155
161
request_factory = OneShotRequestFactory ()
156
162
157
163
# Set up the server side protocol
158
- channel = _PushHTTPChannel (self .reactor )
159
- channel .requestFactory = request_factory
160
- channel .site = self .site
164
+ channel = _PushHTTPChannel (self .reactor , request_factory , self .site )
161
165
162
166
# Connect client to server and vice versa.
163
167
client_to_server_transport = FakeTransport (
@@ -188,8 +192,9 @@ def assert_request_is_get_repl_stream_updates(
188
192
fetching updates for given stream.
189
193
"""
190
194
195
+ path = request .path # type: bytes # type: ignore
191
196
self .assertRegex (
192
- request . path ,
197
+ path ,
193
198
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
194
199
% (stream_name .encode ("ascii" ),),
195
200
)
@@ -390,9 +395,7 @@ def _handle_http_replication_attempt(self, hs, repl_port):
390
395
request_factory = OneShotRequestFactory ()
391
396
392
397
# Set up the server side protocol
393
- channel = _PushHTTPChannel (self .reactor )
394
- channel .requestFactory = request_factory
395
- channel .site = self ._hs_to_site [hs ]
398
+ channel = _PushHTTPChannel (self .reactor , request_factory , self ._hs_to_site [hs ])
396
399
397
400
# Connect client to server and vice versa.
398
401
client_to_server_transport = FakeTransport (
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
475
478
makes it very hard to test.
476
479
"""
477
480
478
- def __init__ (self , reactor : IReactorTime ):
481
+ def __init__ (
482
+ self , reactor : IReactorTime , request_factory : Callable [..., Request ], site : Site
483
+ ):
479
484
super ().__init__ ()
480
485
self .reactor = reactor
486
+ self .requestFactory = request_factory
487
+ self .site = site
481
488
482
489
self ._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
483
490
0 commit comments