@@ -31,6 +31,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
3131 int *used_list_len,
3232 int *free_list,
3333 int *free_list_len,
34+ int64_t *first_token_ids,
3435 const int bsz,
3536 const int block_size,
3637 const int block_num_per_seq,
@@ -43,6 +44,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
4344 int *block_table_now = block_tables + tid * block_num_per_seq;
4445 if (stop_flags[tid] && !is_block_step[tid]) {
4546 // Reclaim (recycle) the blocks
47+ first_token_ids[tid] = -1 ;
4648 const int encoder_block_len = encoder_block_lens[tid];
4749 const int decoder_used_len = used_list_len[tid];
4850 if (decoder_used_len > 0 ) {
@@ -166,11 +168,11 @@ __global__ void recover_block(int *recover_block_list, // [bsz]
166168 int *encoder_block_lens,
167169 int *used_list_len,
168170 const int64_t *next_tokens,
171+ const int64_t *first_token_ids,
169172 const int bsz,
170173 const int block_num_per_seq,
171174 const int length,
172- const int pre_id_length,
173- const int first_token_id) {
175+ const int pre_id_length) {
174176 const int bid = blockIdx .x ;
175177 const int tid = threadIdx .x ;
176178 __shared__ int ori_free_list_len;
@@ -189,7 +191,8 @@ __global__ void recover_block(int *recover_block_list, // [bsz]
189191 seq_lens_encoder[recover_id] = seq_len;
190192 stop_flags[recover_id] = false ;
191193 input_ids_now[ori_seq_len_encoder + step_idx_now - 1 ] = next_tokens[recover_id]; // next tokens
192- input_ids_now[0 ] = first_token_id; // set first prompt token
194+ input_ids_now[0 ] =
195+ first_token_ids[recover_id]; // set first prompt token
193196 const int ori_free_list_len_tid0 = atomicSub (free_list_len, decoder_used_len);
194197 ori_free_list_len = ori_free_list_len_tid0;
195198#ifdef DEBUG_STEP
@@ -234,9 +237,9 @@ void StepPaddle(const paddle::Tensor& stop_flags,
234237 const paddle::Tensor& pre_ids,
235238 const paddle::Tensor& step_idx,
236239 const paddle::Tensor& next_tokens,
240+ const paddle::Tensor &first_token_ids,
237241 const int block_size,
238242 const int encoder_decoder_block_num,
239- const int64_t first_token_id,
240243 const int speculate_step_token_num) {
241244 auto cu_stream = seq_lens_this_time.stream ();
242245 const int bsz = seq_lens_this_time.shape ()[0 ];
@@ -264,6 +267,7 @@ void StepPaddle(const paddle::Tensor& stop_flags,
264267 const_cast <int *>(used_list_len.data <int >()),
265268 const_cast <int *>(free_list.data <int >()),
266269 const_cast <int *>(free_list_len.data <int >()),
270+ const_cast <int64_t *>(first_token_ids.data <int64_t >()),
267271 bsz,
268272 block_size,
269273 block_num_per_seq,
@@ -300,11 +304,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
300304 const_cast <int *>(encoder_block_lens.data <int >()),
301305 const_cast <int *>(used_list_len.data <int >()),
302306 next_tokens.data <int64_t >(),
307+ first_token_ids.data <int64_t >(),
303308 bsz,
304309 block_num_per_seq,
305310 length,
306- pre_id_length,
307- first_token_id
311+ pre_id_length
308312 );
309313#ifdef DEBUG_STEP
310314#ifdef PADDLE_WITH_HIP
@@ -337,10 +341,10 @@ PD_BUILD_OP(step_paddle)
337341 " input_ids" ,
338342 " pre_ids" ,
339343 " step_idx" ,
340- " next_tokens" })
344+ " next_tokens" ,
345+ " first_token_ids" ,})
341346 .Attrs({" block_size: int" ,
342347 " encoder_decoder_block_num: int" ,
343- " first_token_id: int64_t" ,
344348 " speculate_step_token_num: int" })
345349 .Outputs({" stop_flags_out" ,
346350 " seq_lens_this_time_out" ,
@@ -358,7 +362,8 @@ PD_BUILD_OP(step_paddle)
358362 " used_list_len_out" ,
359363 " free_list_out" ,
360364 " free_list_len_out" ,
361- " input_ids_out" })
365+ " input_ids_out" ,
366+ " first_token_ids_out" ,})
362367 .SetInplaceMap({{" stop_flags" , " stop_flags_out" },
363368 {" seq_lens_this_time" , " seq_lens_this_time_out" },
364369 {" seq_lens_encoder" , " seq_lens_encoder_out" },
@@ -375,5 +380,6 @@ PD_BUILD_OP(step_paddle)
375380 {" used_list_len" , " used_list_len_out" },
376381 {" free_list" , " free_list_out" },
377382 {" free_list_len" , " free_list_len_out" },
378- {" input_ids" , " input_ids_out" }})
383+ {" input_ids" , " input_ids_out" },
384+ {" first_token_ids" , " first_token_ids_out" }})
379385 .SetKernelFn(PD_KERNEL(StepPaddle));
0 commit comments