Skip to content

Commit caac838

Browse files
authored
Merge branch 'main' into tde/golden_print_regression
2 parents 64e8bf1 + 8954e04 commit caac838

File tree

21 files changed

+1145
-519
lines changed

21 files changed

+1145
-519
lines changed

.github/workflows/check_api_backwards_compatibility_workflow.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
   # Default baseline for automatic PR checks
   # Can be: branch name (e.g., 'main'), commit hash, or tag
   # Will be resolved to commit hash during execution
-  DEFAULT_BASELINE: '29a810e644d079a91955c0ab98afb0798b10ab52'
+  DEFAULT_BASELINE: '53bbf7a23d7194de1fbe991ba120a0d49bd5b097'
   # Tag pattern for auto-detection (e.g., 'core_r*', 'core_v*')
   TAG_PATTERN: 'core_v*'
   # Tag regex filter (e.g., '^core_v[0-9]+\.[0-9]+\.[0-9]+$' for stable versions only)

examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ async def main(
     engine: DynamicInferenceEngine,
     requests: List[Request],
     port: int,
-    mp_port: int,
     sampling_params: SamplingParams | None = None,
 ):
     if sampling_params is not None:
@@ -58,7 +57,6 @@ async def main(

     await engine.start_listening_to_data_parallel_coordinator(
         inference_coordinator_port=port,
-        inference_mp_coordinator_port=mp_port,
         launch_inference_coordinator=True,
         verbose=True,
     )
@@ -258,6 +256,5 @@ async def main(
             engine,
             requests,
             args.inference_coordinator_port,
-            args.inference_mp_coordinator_port
         )
     )

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,6 @@ def initialize_attention_state(
         self.padded_active_token_count = min(
             self.padded_active_token_count, self.max_active_requests
         )
-        self.padding_slice = slice(active_token_count, self.padded_active_token_count)

         # How are we calculating the padded active request count?
         # Case 1: Using cuda graphs:

megatron/core/inference/data_parallel_inference_coordinator.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
             self.identities_of_data_parallel_ranks.append(identity)
         logging.info("Inference Coordinator: Connected with data parallel ranks...")
         self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
+        self.data_parallel_pause_acks = set()
+        self.data_parallel_stop_acks = set()

         self.request_id_to_client_id = {}
         self.request_id_to_client_request_id = {}
@@ -151,7 +153,7 @@ def start(self):
                 # print(f"New client connected: {sender_identity}")
                 known_clients.add(sender_identity)
                 self.router_socket.send_multipart(
-                    [sender_identity, msgpack.packb([Headers.ACK.value], use_bin_type=True)]
+                    [sender_identity, msgpack.packb([Headers.CONNECT_ACK.value], use_bin_type=True)]
                 )

             elif header == Headers.SUBMIT_REQUEST:
@@ -208,6 +210,50 @@ def start(self):
                 self.router_socket.send_multipart(
                     [data_parallel_rank_id, msgpack.packb([header.value], use_bin_type=True)]
                 )
+                if header == Headers.UNPAUSE:
+                    self.data_parallel_pause_acks = set()
+            elif header == Headers.PAUSE_ACK:
+                # control signal ack from the engine
+                assert sender_identity in self.identities_of_data_parallel_ranks
+                assert sender_identity not in self.data_parallel_pause_acks
+                self.data_parallel_pause_acks.add(sender_identity)
+                # route to all clients only once we have gotten an ack from all data parallel ranks
+                if len(self.data_parallel_pause_acks) == self.data_parallel_size:
+                    for client_id in known_clients:
+                        self.router_socket.send_multipart(
+                            [
+                                client_id,
+                                msgpack.packb([header.value, sender_identity], use_bin_type=True),
+                            ]
+                        )
+                    for data_parallel_rank_id in self.identities_of_data_parallel_ranks:
+                        self.router_socket.send_multipart(
+                            [
+                                data_parallel_rank_id,
+                                msgpack.packb([Headers.PAUSE_ACK.value], use_bin_type=True),
+                            ]
+                        )
+            elif header == Headers.STOP_ACK:
+                # control signal ack from the engine
+                assert sender_identity in self.identities_of_data_parallel_ranks
+                assert sender_identity not in self.data_parallel_stop_acks
+                self.data_parallel_stop_acks.add(sender_identity)
+                # route to all clients only once we have gotten an ack from all data parallel ranks
+                if len(self.data_parallel_stop_acks) == self.data_parallel_size:
+                    for client_id in known_clients:
+                        self.router_socket.send_multipart(
+                            [
+                                client_id,
+                                msgpack.packb([header.value, sender_identity], use_bin_type=True),
+                            ]
+                        )
+                    for data_parallel_rank_id in self.identities_of_data_parallel_ranks:
+                        self.router_socket.send_multipart(
+                            [
+                                data_parallel_rank_id,
+                                msgpack.packb([Headers.STOP_ACK.value], use_bin_type=True),
+                            ]
+                        )
             elif header == Headers.ENGINE_REPLY:
                 # This is the output of a single engine step on some data parallel rank.
                 assert sender_identity in self.identities_of_data_parallel_ranks
@@ -224,7 +270,7 @@ def start(self):
                     [
                         client_identity,
                         msgpack.packb(
-                            [client_request_identity, finished_request_record],
+                            [header.value, client_request_identity, finished_request_record],
                             use_bin_type=True,
                         ),
                     ]

0 commit comments

Comments (0)