@@ -109,6 +109,8 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
109109 self .identities_of_data_parallel_ranks .append (identity )
110110 logging .info ("Inference Coordinator: Connected with data parallel ranks..." )
111111 self .data_parallel_rank_iterator = cycle (self .identities_of_data_parallel_ranks )
112+ self .data_parallel_pause_acks = set ()
113+ self .data_parallel_stop_acks = set ()
112114
113115 self .request_id_to_client_id = {}
114116 self .request_id_to_client_request_id = {}
@@ -151,7 +153,7 @@ def start(self):
151153 # print(f"New client connected: {sender_identity}")
152154 known_clients .add (sender_identity )
153155 self .router_socket .send_multipart (
154- [sender_identity , msgpack .packb ([Headers .ACK .value ], use_bin_type = True )]
156+ [sender_identity , msgpack .packb ([Headers .CONNECT_ACK .value ], use_bin_type = True )]
155157 )
156158
157159 elif header == Headers .SUBMIT_REQUEST :
@@ -208,6 +210,50 @@ def start(self):
208210 self .router_socket .send_multipart (
209211 [data_parallel_rank_id , msgpack .packb ([header .value ], use_bin_type = True )]
210212 )
213+ if header == Headers .UNPAUSE :
214+ self .data_parallel_pause_acks = set ()
215+ elif header == Headers .PAUSE_ACK :
216+ # control signal ack from the engine
217+ assert sender_identity in self .identities_of_data_parallel_ranks
218+ assert sender_identity not in self .data_parallel_pause_acks
219+ self .data_parallel_pause_acks .add (sender_identity )
220+ # route to all clients only once we have gotten an ack from all data parallel ranks
221+ if len (self .data_parallel_pause_acks ) == self .data_parallel_size :
222+ for client_id in known_clients :
223+ self .router_socket .send_multipart (
224+ [
225+ client_id ,
226+ msgpack .packb ([header .value , sender_identity ], use_bin_type = True ),
227+ ]
228+ )
229+ for data_parallel_rank_id in self .identities_of_data_parallel_ranks :
230+ self .router_socket .send_multipart (
231+ [
232+ data_parallel_rank_id ,
233+ msgpack .packb ([Headers .PAUSE_ACK .value ], use_bin_type = True ),
234+ ]
235+ )
236+ elif header == Headers .STOP_ACK :
237+ # control signal ack from the engine
238+ assert sender_identity in self .identities_of_data_parallel_ranks
239+ assert sender_identity not in self .data_parallel_stop_acks
240+ self .data_parallel_stop_acks .add (sender_identity )
241+ # route to all clients only once we have gotten an ack from all data parallel ranks
242+ if len (self .data_parallel_stop_acks ) == self .data_parallel_size :
243+ for client_id in known_clients :
244+ self .router_socket .send_multipart (
245+ [
246+ client_id ,
247+ msgpack .packb ([header .value , sender_identity ], use_bin_type = True ),
248+ ]
249+ )
250+ for data_parallel_rank_id in self .identities_of_data_parallel_ranks :
251+ self .router_socket .send_multipart (
252+ [
253+ data_parallel_rank_id ,
254+ msgpack .packb ([Headers .STOP_ACK .value ], use_bin_type = True ),
255+ ]
256+ )
211257 elif header == Headers .ENGINE_REPLY :
212258 # This is the output of a single engine step on some data parallel rank.
213259 assert sender_identity in self .identities_of_data_parallel_ranks
@@ -224,7 +270,7 @@ def start(self):
224270 [
225271 client_identity ,
226272 msgpack .packb (
227- [client_request_identity , finished_request_record ],
273+ [header . value , client_request_identity , finished_request_record ],
228274 use_bin_type = True ,
229275 ),
230276 ]
0 commit comments