// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <type_traits>

#include "helper.h"
#include "paddle/extension.h"
17+
// Prepares the draft model's next-step inputs from the base model's
// speculative-decoding verification results (accepted tokens per batch slot).
// Launched as a SINGLE block: thread `tid` owns batch slot `tid`, so `bsz`
// must not exceed THREADBLOCK_SIZE. All threads participate in a block-wide
// reduction at the end to decide whether any sequence is still running.
//
// EAGLE (template flag) selects the EAGLE-style draft model, which offsets
// the first-token write position and sequence length by one relative to the
// non-EAGLE variant.
template <int THREADBLOCK_SIZE, bool EAGLE>
__global__ void draft_model_preprocess_kernel(
    int64_t* draft_tokens,        // [bsz, draft_tokens_len] draft input tokens (in/out)
    int64_t* input_ids,           // [bsz, input_ids_len] prompt/input ids (in/out)
    bool* stop_flags,             // [bsz] draft-model stop flags (in/out)
    int* seq_lens_this_time,      // [bsz] tokens to process this step (out)
    int* seq_lens_encoder,        // [bsz] encoder (prefill) lengths (out)
    int* seq_lens_decoder,        // [bsz] decoder lengths (in/out)
    int64_t* step_idx,            // [bsz] draft-model step counters (in/out)
    int* first_token_record,      // [bsz] >0 means first token pending; consumed here
    bool* not_need_stop,          // [1] set true iff any slot is still active
    const int64_t* accept_tokens,             // [bsz, accept_tokens_len] tokens accepted by base model
    const int* accept_num,                    // [bsz] number of accepted tokens per slot
    const int* base_model_seq_lens_encoder,   // [bsz] (unused here; kept for interface parity)
    const int* base_model_seq_lens_decoder,   // [bsz] base model decoder lengths
    const int64_t* base_model_step_idx,       // [bsz] base model step counters
    const bool* base_model_stop_flags,        // [bsz] base model stop flags
    int64_t* base_model_draft_tokens,         // [bsz, base_model_draft_tokens_len] cleared to -1 (except slot 0)
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;

  if (tid < bsz) {
    auto base_model_step_idx_now = base_model_step_idx[tid];
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
    auto accept_num_now = accept_num[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * base_model_draft_tokens_len;
    // Invalidate all base-model draft slots except index 0 (index 0 is
    // presumably the already-committed token — TODO confirm against the
    // base-model consumer of this buffer).
#pragma unroll
    for (int i = 1; i < base_model_draft_tokens_len; i++) {
      base_model_draft_tokens_now[i] = -1;
    }

    if (!base_model_stop_flags[tid]) {
      not_stop_flag = 1;
      // 1. Base model has produced no token yet: nothing for the draft
      // model to do this step, and this slot does not keep the loop alive.
      if (base_model_step_idx_now == 0) {
        seq_lens_this_time[tid] = 0;
        not_stop_flag = 0;
      } else if (base_model_step_idx_now == 1 && first_token_record[tid] > 0) {
        // 2. First token from the base model: (re)start the draft model with
        // an encoder pass. Can be extended to first few tokens.
        seq_lens_encoder[tid] = first_token_record[tid];
        first_token_record[tid] = -1;  // consume the one-shot record
        stop_flags[tid] = false;
        int64_t base_model_first_token = accept_tokens_now[0];
        int position = base_model_seq_lens_decoder[tid];
        if (EAGLE) {
          // EAGLE writes the token one position earlier and does not add 1
          // to the length (EAGLE conditions on the previous token).
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = base_model_seq_lens_decoder[tid];
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = base_model_seq_lens_decoder[tid] + 1;
        }
      } else if (accept_num_now <=
                 max_draft_token) /* Accept partial draft tokens*/ {
        // 3. Base model rejected some draft tokens: roll the draft state
        // back by the number of rejected tokens.
        if (stop_flags[tid]) {
          // Draft had stopped but base model kept going ("reject stop"):
          // resync draft state from the base model.
          stop_flags[tid] = false;
          seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
          step_idx[tid] = base_model_step_idx[tid];
        } else {
          seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
          step_idx[tid] -= max_draft_token - accept_num_now;
        }
        // Restart drafting from the last accepted token.
        int64_t modified_token = accept_tokens_now[accept_num_now - 1];
        draft_tokens_now[0] = modified_token;
        seq_lens_this_time[tid] = 1;

      } else /* Accept all draft tokens*/ {
        // 4. Everything accepted: feed the bonus token and advance by two.
        draft_tokens_now[1] = accept_tokens_now[max_draft_token];
        seq_lens_this_time[tid] = 2;
      }
    } else {
      // Base model finished this slot: stop the draft model too.
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
    }
  }
  __syncthreads();
  // Block-wide OR (via sum): any active slot keeps the generation loop alive.
  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}
112+
113+
114+ void DraftModelPreprocess (const paddle::Tensor& draft_tokens,
115+ const paddle::Tensor& input_ids,
116+ const paddle::Tensor& stop_flags,
117+ const paddle::Tensor& seq_lens_this_time,
118+ const paddle::Tensor& seq_lens_encoder,
119+ const paddle::Tensor& seq_lens_decoder,
120+ const paddle::Tensor& step_idx,
121+ const paddle::Tensor& first_token_record,
122+ const paddle::Tensor& not_need_stop,
123+ const paddle::Tensor& accept_tokens,
124+ const paddle::Tensor& accept_num,
125+ const paddle::Tensor& base_model_seq_lens_encoder,
126+ const paddle::Tensor& base_model_seq_lens_decoder,
127+ const paddle::Tensor& base_model_step_idx,
128+ const paddle::Tensor& base_model_stop_flags,
129+ const paddle::Tensor& base_model_draft_tokens,
130+ const int max_draft_token,
131+ const std::string& draft_type) {
132+ int real_bsz = seq_lens_this_time.shape ()[0 ];
133+ int accept_tokens_len = accept_tokens.shape ()[1 ];
134+ int input_ids_len = input_ids.shape ()[1 ];
135+ int draft_tokens_len = draft_tokens.shape ()[1 ];
136+ auto cu_stream = seq_lens_this_time.stream ();
137+ constexpr int BlockSize = 256 ;
138+ int base_model_draft_tokens_len = base_model_draft_tokens.shape ()[1 ];
139+ auto not_need_stop_gpu =
140+ not_need_stop.copy_to (seq_lens_this_time.place (), false );
141+
142+
143+ if (draft_type == " eagle" ) {
144+ draft_model_preprocess_kernel<BlockSize, true >
145+ <<<1 , BlockSize, 0 , cu_stream>>> (
146+ const_cast <int64_t *>(draft_tokens.data <int64_t >()),
147+ const_cast <int64_t *>(input_ids.data <int64_t >()),
148+ const_cast <bool *>(stop_flags.data <bool >()),
149+ const_cast <int *>(seq_lens_this_time.data <int >()),
150+ const_cast <int *>(seq_lens_encoder.data <int >()),
151+ const_cast <int *>(seq_lens_decoder.data <int >()),
152+ const_cast <int64_t *>(step_idx.data <int64_t >()),
153+ const_cast <int *>(first_token_record.data <int >()),
154+ const_cast <bool *>(not_need_stop_gpu.data <bool >()),
155+ accept_tokens.data <int64_t >(),
156+ accept_num.data <int >(),
157+ base_model_seq_lens_encoder.data <int >(),
158+ base_model_seq_lens_decoder.data <int >(),
159+ base_model_step_idx.data <int64_t >(),
160+ base_model_stop_flags.data <bool >(),
161+ const_cast <int64_t *>(base_model_draft_tokens.data <int64_t >()),
162+ real_bsz,
163+ max_draft_token,
164+ accept_tokens_len,
165+ draft_tokens_len,
166+ input_ids_len,
167+ base_model_draft_tokens_len);
168+ } else {
169+ draft_model_preprocess_kernel<BlockSize, false >
170+ <<<1 , BlockSize, 0 , cu_stream>>> (
171+ const_cast <int64_t *>(draft_tokens.data <int64_t >()),
172+ const_cast <int64_t *>(input_ids.data <int64_t >()),
173+ const_cast <bool *>(stop_flags.data <bool >()),
174+ const_cast <int *>(seq_lens_this_time.data <int >()),
175+ const_cast <int *>(seq_lens_encoder.data <int >()),
176+ const_cast <int *>(seq_lens_decoder.data <int >()),
177+ const_cast <int64_t *>(step_idx.data <int64_t >()),
178+ const_cast <int *>(first_token_record.data <int >()),
179+ const_cast <bool *>(not_need_stop_gpu.data <bool >()),
180+ accept_tokens.data <int64_t >(),
181+ accept_num.data <int >(),
182+ base_model_seq_lens_encoder.data <int >(),
183+ base_model_seq_lens_decoder.data <int >(),
184+ base_model_step_idx.data <int64_t >(),
185+ base_model_stop_flags.data <bool >(),
186+ const_cast <int64_t *>(base_model_draft_tokens.data <int64_t >()),
187+ real_bsz,
188+ max_draft_token,
189+ accept_tokens_len,
190+ draft_tokens_len,
191+ input_ids_len,
192+ base_model_draft_tokens_len);
193+ }
194+
195+
196+ auto not_need_stop_cpu =
197+ not_need_stop_gpu.copy_to (not_need_stop.place (), false );
198+ bool * not_need_stop_data = const_cast <bool *>(not_need_stop.data <bool >());
199+ not_need_stop_data[0 ] = not_need_stop_cpu.data <bool >()[0 ];
200+ }
201+
202+
203+ PD_BUILD_OP (draft_model_preprocess)
204+ .Inputs({" draft_tokens" ,
205+ " input_ids" ,
206+ " stop_flags" ,
207+ " seq_lens_this_time" ,
208+ " seq_lens_encoder" ,
209+ " seq_lens_decoder" ,
210+ " step_idx" ,
211+ " first_token_record" ,
212+ " not_need_stop" ,
213+ " accept_tokens" ,
214+ " accept_num" ,
215+ " base_model_seq_lens_encoder" ,
216+ " base_model_seq_lens_decoder" ,
217+ " base_model_step_idx" ,
218+ " base_model_stop_flags" ,
219+ " base_model_draft_tokens" })
220+ .Outputs({" draft_tokens_out" ,
221+ " input_ids_out" ,
222+ " stop_flags_out" ,
223+ " seq_lens_this_time_out" ,
224+ " seq_lens_encoder_out" ,
225+ " seq_lens_decoder_out" ,
226+ " step_idx_out" ,
227+ " not_need_stop_out" ,
228+ " first_token_record_out" })
229+ .Attrs({" max_draft_token: int" , " draft_type: std::string" })
230+ .SetInplaceMap({{" draft_tokens" , " draft_tokens_out" },
231+ {" input_ids" , " input_ids_out" },
232+ {" stop_flags" , " stop_flags_out" },
233+ {" seq_lens_this_time" , " seq_lens_this_time_out" },
234+ {" seq_lens_encoder" , " seq_lens_encoder_out" },
235+ {" seq_lens_decoder" , " seq_lens_decoder_out" },
236+ {" step_idx" , " step_idx_out" },
237+ {" not_need_stop" , " not_need_stop_out" },
238+ {" first_token_record" , " first_token_record_out" }})
239+ .SetKernelFn(PD_KERNEL(DraftModelPreprocess));