
Commit 64e8bf1

Merge branch 'main' into tde/golden_print_regression
2 parents 9b62acb + e35495d commit 64e8bf1

File tree: 33 files changed, +1739 -848 lines

.github/workflows/check_api_backwards_compatibility_workflow.yml

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ jobs:
       # Default baseline for automatic PR checks
       # Can be: branch name (e.g., 'main'), commit hash, or tag
       # Will be resolved to commit hash during execution
-      DEFAULT_BASELINE: 'f7fb5ecbe218672719053fa304d91767ce30ffa1'
+      DEFAULT_BASELINE: '29a810e644d079a91955c0ab98afb0798b10ab52'
       # Tag pattern for auto-detection (e.g., 'core_r*', 'core_v*')
       TAG_PATTERN: 'core_v*'
       # Tag regex filter (e.g., '^core_v[0-9]+\.[0-9]+\.[0-9]+$' for stable versions only)
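The comments in this hunk say DEFAULT_BASELINE may be a branch name, commit hash, or tag, and that it is resolved to a commit hash during execution. A rough sketch of that resolution step in Python, assuming git rev-parse as the resolution mechanism; the helper name resolve_baseline is hypothetical and not part of this workflow:

import subprocess

def resolve_baseline(ref: str, repo_dir: str = ".") -> str:
    """Resolve a branch name, tag, or commit hash to a full commit hash.

    Hypothetical helper mirroring the 'resolved to commit hash' comment above;
    the actual workflow may implement this differently.
    """
    result = subprocess.run(
        ["git", "rev-parse", "--verify", f"{ref}^{{commit}}"],
        cwd=repo_dir,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()

# e.g. resolve_baseline('main')
# or resolve_baseline('29a810e644d079a91955c0ab98afb0798b10ab52')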

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 16 additions & 7 deletions
@@ -302,8 +302,8 @@ def __init__(
             assert (
                 mamba_ssm_states_shape is not None
             ), "`mamba_ssm_states_shape` must be specified for hybrid models"
-            assert (
-                not use_cuda_graphs_for_non_decode_steps
+            assert not (
+                num_cuda_graphs is not None and use_cuda_graphs_for_non_decode_steps
             ), "Non-decode CUDA graphs not yet supported for hybrid models"

             # For hybrid models, the layer map converts the global layer index to the
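The reworked assertion only rejects the case where CUDA graphs are actually enabled (num_cuda_graphs is not None) and non-decode steps are also graphed, so a hybrid model that sets use_cuda_graphs_for_non_decode_steps=True with CUDA graphs disabled no longer fails. A minimal standalone check of the two conditions, written as free functions for illustration rather than the class code:

def old_guard_fails(use_cuda_graphs_for_non_decode_steps: bool) -> bool:
    # old check: fails whenever the flag is set, even with CUDA graphs disabled
    return use_cuda_graphs_for_non_decode_steps

def new_guard_fails(num_cuda_graphs, use_cuda_graphs_for_non_decode_steps: bool) -> bool:
    # new check: fails only when CUDA graphs are enabled AND non-decode steps are graphed
    return num_cuda_graphs is not None and use_cuda_graphs_for_non_decode_steps

assert old_guard_fails(True) is True
assert new_guard_fails(None, True) is False  # now allowed for hybrid models
assert new_guard_fails(4, True) is True      # still rejected
assert new_guard_fails(4, False) is False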
@@ -1079,6 +1079,7 @@ def initialize_attention_state(
         self.padded_active_token_count = min(
             self.padded_active_token_count, self.max_active_requests
         )
+        self.padding_slice = slice(active_token_count, self.padded_active_token_count)

         # How are we calculating the padded active request count?
         # Case 1: Using cuda graphs:
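The new padding_slice records the padded token positions between the true active token count and the padded count. The diff does not show how the slice is consumed; as an assumption, a slice like this is typically used to blank the padded tail of a token-indexed buffer, roughly:

import torch

active_token_count = 5
padded_active_token_count = 8  # e.g. rounded up to a CUDA-graph-friendly size
padding_slice = slice(active_token_count, padded_active_token_count)

logits = torch.randn(padded_active_token_count, 4)
logits[padding_slice] = 0.0  # zero out the padded positions (illustrative use only)
print(logits[padding_slice])  # three all-zero rows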
@@ -1427,6 +1428,14 @@ def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
         if self.is_hybrid_model:
             tensor_swap(self.mamba_metadata.request_to_mamba_state_idx, src_idxs, dst_idxs)

+    def get_index_of_chunked_prefill_request(self) -> int:
+        """Get the index of the chunked prefill request in the context.
+
+        Return:
+            (int) Index of the chunked prefill request, or -1 if none exists.
+        """
+        return torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0]
+
     # TODO: see if we can compile this function
     def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> Tensor:
         """Update context state after calling engine.step().
@@ -1583,8 +1592,9 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> Tensor:

         if self.chunked_prefill_request_id != -1:
             # find the id in request_ids that is the chunked_prefill_request_id. Only one request should be chunked.
-            pos = torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0]
-            active_requests_requiring_new_block[pos] = 0  # chunked prefill should not be paused
+            active_requests_requiring_new_block[self.get_index_of_chunked_prefill_request()] = (
+                0  # chunked prefill should not be paused
+            )

         active_requests_requiring_new_block_count = (
             (active_requests_requiring_new_block == 1).sum().item()
@@ -1651,11 +1661,10 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> Tensor:
         active_request_count += resume_request_count
         assert active_request_count > 0, "active_request_count == %d." % active_request_count

-        # finally, swap the chunked prefill to the end of the active requests to obey the invariant
+        # finally, swap the chunked prefill to the end of the active requests to obey the invariance
         if self.chunked_prefill_request_id != -1:
-            pos = torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0]
             self._swap_book_keeping_tensors(
-                src_idxs=torch.tensor([pos]),
+                src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]),
                 dst_idxs=torch.tensor([active_request_count + self.paused_request_count - 1]),
                 next_tokens=next_tokens,
             )
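This final hunk keeps the chunked prefill request in the last active slot by swapping its bookkeeping entries with those at index active_request_count + self.paused_request_count - 1. A minimal sketch of that swap on a single id tensor; tensor_swap below is a stand-in for the multi-tensor helper used in the real code:

import torch

def tensor_swap(t: torch.Tensor, src_idxs: torch.Tensor, dst_idxs: torch.Tensor) -> None:
    # swap rows src_idxs <-> dst_idxs in place
    t[src_idxs], t[dst_idxs] = t[dst_idxs].clone(), t[src_idxs].clone()

request_ids = torch.tensor([7, 42, 11, 99])  # 7 is the chunked prefill request
chunked_prefill_request_id = 7
active_request_count, paused_request_count = 4, 0

pos = torch.where(request_ids == chunked_prefill_request_id)[0][0]
tensor_swap(
    request_ids,
    src_idxs=torch.tensor([pos]),
    dst_idxs=torch.tensor([active_request_count + paused_request_count - 1]),
)
print(request_ids)  # tensor([99, 42, 11,  7]) -- chunked prefill now last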
