@@ -302,8 +302,8 @@ def __init__(
             assert (
                 mamba_ssm_states_shape is not None
             ), "`mamba_ssm_states_shape` must be specified for hybrid models"
-            assert (
-                not use_cuda_graphs_for_non_decode_steps
+            assert not (
+                num_cuda_graphs is not None and use_cuda_graphs_for_non_decode_steps
             ), "Non-decode CUDA graphs not yet supported for hybrid models"

             # For hybrid models, the layer map converts the global layer index to the
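As an aside (not part of the patch above): the reworked guard only rejects hybrid-model configurations when CUDA graphs are actually enabled. A minimal standalone sketch of the old versus new condition, using hypothetical helper names and flag values:

# Illustrative sketch only; old_guard/new_guard are hypothetical helpers that
# mirror the assert conditions above, not functions from the codebase.
def old_guard(num_cuda_graphs, use_cuda_graphs_for_non_decode_steps):
    # Old assert: fails whenever non-decode CUDA graphs are requested,
    # even if CUDA graphs are disabled entirely.
    return not use_cuda_graphs_for_non_decode_steps

def new_guard(num_cuda_graphs, use_cuda_graphs_for_non_decode_steps):
    # New assert: only fails when CUDA graphs are enabled (num_cuda_graphs set)
    # AND non-decode steps are to be graphed.
    return not (num_cuda_graphs is not None and use_cuda_graphs_for_non_decode_steps)

assert old_guard(None, True) is False   # old check would trip the assert here
assert new_guard(None, True) is True    # new check allows it when graphs are off
assert new_guard(4, True) is False      # still rejected when graphs are enabled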
@@ -1079,6 +1079,7 @@ def initialize_attention_state(
         self.padded_active_token_count = min(
             self.padded_active_token_count, self.max_active_requests
         )
+        self.padding_slice = slice(active_token_count, self.padded_active_token_count)

         # How are we calculating the padded active request count?
         # Case 1: Using cuda graphs:
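As an aside (not part of the patch): slice(active_token_count, padded_active_token_count) is an ordinary Python slice object, so the padded tail of a token-sized buffer can be addressed in one indexing operation. A small sketch with made-up counts:

import torch

# Hypothetical counts, just to show what the stored slice covers.
active_token_count = 5
padded_active_token_count = 8
padding_slice = slice(active_token_count, padded_active_token_count)

tokens = torch.arange(padded_active_token_count)
tokens[padding_slice] = 0   # e.g. zero out the padding positions in one shot
print(tokens)               # tensor([0, 1, 2, 3, 4, 0, 0, 0])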
@@ -1427,6 +1428,14 @@ def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens):
         if self.is_hybrid_model:
             tensor_swap(self.mamba_metadata.request_to_mamba_state_idx, src_idxs, dst_idxs)

+    def get_index_of_chunked_prefill_request(self) -> int:
+        """Get the index of the chunked prefill request in the context.
+
+        Return:
+            (int) Index of the chunked prefill request, or -1 if none exists.
+        """
+        return torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0]
+
     # TODO: see if we can compile this function
     def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> Tensor:
         """Update context state after calling engine.step().
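As an aside (not part of the patch): a standalone sketch of what the torch.where(...)[0][0] lookup in the new helper returns, with made-up request ids, plus a caveat about the empty case:

import torch

# Hypothetical ids; in the context these live in self.request_ids.
request_ids = torch.tensor([7, 3, 12, 9])
chunked_prefill_request_id = 12

idx = torch.where(request_ids == chunked_prefill_request_id)[0][0]
print(idx)   # tensor(2) -- a 0-dim tensor, usable directly as an index
# Caveat: if the id is absent, the trailing [0] raises IndexError rather than
# returning -1, so callers check `chunked_prefill_request_id != -1` first.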
@@ -1583,8 +1592,9 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T

         if self.chunked_prefill_request_id != -1:
             # find the id in request_ids that is the chunked_prefill_request_id. Only one request should be chunked.
-            pos = torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0]
-            active_requests_requiring_new_block[pos] = 0  # chunked prefill should not be paused
+            active_requests_requiring_new_block[self.get_index_of_chunked_prefill_request()] = (
+                0  # chunked prefill should not be paused
+            )

         active_requests_requiring_new_block_count = (
             (active_requests_requiring_new_block == 1).sum().item()
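As an aside (not part of the patch): zeroing one position in the mask before summing is what keeps the chunked-prefill request out of the "needs a new block" count. A tiny sketch with a made-up mask and index:

import torch

# Hypothetical mask and position of the chunked-prefill request.
active_requests_requiring_new_block = torch.tensor([1, 0, 1, 1])
chunked_prefill_index = 2

active_requests_requiring_new_block[chunked_prefill_index] = 0   # do not pause it
count = (active_requests_requiring_new_block == 1).sum().item()
print(count)   # 2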
@@ -1651,11 +1661,10 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T
         active_request_count += resume_request_count
         assert active_request_count > 0, "active_request_count == %d." % active_request_count

-        # finally, swap the chunked prefill to the end of the active requests to obey the invariant
+        # finally, swap the chunked prefill to the end of the active requests to obey the invariance
         if self.chunked_prefill_request_id != -1:
-            pos = torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0]
             self._swap_book_keeping_tensors(
-                src_idxs=torch.tensor([pos]),
+                src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]),
                 dst_idxs=torch.tensor([active_request_count + self.paused_request_count - 1]),
                 next_tokens=next_tokens,
             )
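As an aside (not part of the patch): the swap-to-end step relies on an index-based tensor swap. Below is a minimal, hypothetical stand-in for tensor_swap / _swap_book_keeping_tensors (not the project's implementation) showing how exchanging entries at src_idxs and dst_idxs moves the chunked-prefill request to the last active slot:

import torch

def tensor_swap(t: torch.Tensor, src_idxs: torch.Tensor, dst_idxs: torch.Tensor) -> None:
    # Exchange the entries at src_idxs and dst_idxs in place.
    tmp = t[src_idxs].clone()
    t[src_idxs] = t[dst_idxs]
    t[dst_idxs] = tmp

request_ids = torch.tensor([7, 12, 3, 9])   # hypothetical: chunked prefill at index 1
tensor_swap(request_ids, torch.tensor([1]), torch.tensor([3]))
print(request_ids)                          # tensor([ 7,  9,  3, 12]) -- now last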