1+ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+ #include " helper.h"
16+
// Fused per-row logits post-processing kernel.
//
// Indexing contract visible in this kernel: blockIdx.x selects one batch row
// (`bi`), and the threads of that block cooperate over the row, each walking
// its elements with a stride of blockDim.x. No shared memory is used.
//
// For every batch row the kernel performs, in order:
//   1. pre_ids update: thread 0 stores the newest token of the row into
//      pre_ids[bi][step_idx[bi]] (skipped when the row is flagged as stopped,
//      when both sequence lengths are zero, or when step_idx is out of range).
//   2. min-length: while cur_len[bi] < min_len[bi], every token listed in
//      eos_token_id gets its logit forced to -1e10.
//   3. repeat_times: counts how often each token id occurs in pre_ids[bi].
//   4. penalties: for tokens seen at least once, applies the repetition
//      penalty (multiply negative / divide non-negative logits by
//      penalty_scores[bi]), then subtracts times * frequency_score[bi] and
//      presence_score[bi]; every logit is finally divided by temperatures[bi].
//   5. bad words: every in-range id in bad_words_list gets its logit forced to
//      -1e10.
//
// Parameters:
//   stop_flags        [bs]                     true when the row is finished
//   pre_ids           [bs, length_id]          token history, updated in place
//   input_ids         [bs, length_input_ids]   tokens fed in this step
//   seq_lens_encoder  [bs], seq_lens_decoder [bs]
//   step_idx          [bs]                     write position into pre_ids
//   penalty_scores, frequency_score, presence_score, temperatures,
//   cur_len, min_len                           indexed per batch row
//   eos_token_id      [end_length]
//   bad_words_list    [bad_words_length]
//   repeat_times      [bs, length]             must be zero on entry
//   logits            [bs, length]             updated in place
template <typename T>
__global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags,
                                                             int64_t *pre_ids,
                                                             const int64_t *input_ids,
                                                             const int *seq_lens_encoder,
                                                             const int *seq_lens_decoder,
                                                             const int64_t *step_idx,
                                                             const T *penalty_scores,
                                                             const T *frequency_score,
                                                             const T *presence_score,
                                                             const float *temperatures,
                                                             const int64_t *cur_len,
                                                             const int64_t *min_len,
                                                             const int64_t *eos_token_id,
                                                             const int64_t *bad_words_list,
                                                             int *repeat_times,
                                                             T *logits,
                                                             const int64_t bs,
                                                             const int64_t length,
                                                             const int64_t end_length,
                                                             const int64_t length_id,
                                                             const int64_t bad_words_length,
                                                             const int64_t length_input_ids) {
    const int bi = blockIdx.x;
    const int tid = threadIdx.x;

    // The condition is uniform across the block, so either every thread leaves
    // here or every thread reaches all of the barriers below.
    if (bi >= bs) return;

    // Row views for this block. `bi * length` is evaluated in 64-bit because
    // `length` / `length_id` are int64_t.
    T *logits_now = logits + bi * length;
    int *repeat_times_now = repeat_times + bi * length;
    int64_t *pre_ids_now = pre_ids + bi * length_id;

    // ---- 1. record the newest token of this row in pre_ids -----------------
    // Only one thread writes. All per-row inputs are indexed with `bi` (the row
    // this block owns) and the row stride of pre_ids is `length_id`, matching
    // the read in step 3. There is deliberately no early `return` in this
    // branch: every thread must reach the __syncthreads() that follows.
    if (tid == 0 && !stop_flags[bi]) {
        const int seq_len_enc = seq_lens_encoder[bi];
        const int seq_len_dec = seq_lens_decoder[bi];
        const int64_t step_idx_now = step_idx[bi];
        const bool row_is_active = (seq_len_enc != 0) || (seq_len_dec != 0);
        if (row_is_active && step_idx_now >= 0 && step_idx_now < length_id) {
            const int64_t *input_ids_now = input_ids + bi * length_input_ids;
            if (seq_len_enc > 0) {
                // encoder step: take the last token according to seq_lens_encoder
                pre_ids_now[step_idx_now] = input_ids_now[seq_len_enc - 1];
            } else {
                // decoder step: take the first token
                pre_ids_now[step_idx_now] = input_ids_now[0];
            }
        }
    }
    // Make the freshly written token visible to the whole block before the
    // history is counted in step 3.
    __syncthreads();

    // ---- 2. min_length process ---------------------------------------------
    if (cur_len[bi] < min_len[bi]) {
        // Strided so that EOS lists longer than the block are fully covered.
        for (int64_t i = tid; i < end_length; i += blockDim.x) {
            const int64_t eos_id = eos_token_id[i];
            if (eos_id < 0 || eos_id >= length) continue;  // ignore ids outside the row
            logits_now[eos_id] = -1e10;
        }
    }

    // ---- 3. update repeat_times --------------------------------------------
    for (int64_t i = tid; i < length_id; i += blockDim.x) {
        const int64_t id = pre_ids_now[i];
        // A negative id stops this thread's walk over the history.
        if (id < 0) break;
        // An id outside the row would index past repeat_times; skip it.
        if (id >= length) continue;
        atomicAdd(&repeat_times_now[id], 1);
    }
    // Steps 2 and 3 must be complete before any thread reads logits /
    // repeat_times written by another thread.
    __syncthreads();

    // ---- 4. penalty_scores process -----------------------------------------
    const float alpha = static_cast<float>(penalty_scores[bi]);
    const float beta = static_cast<float>(frequency_score[bi]);
    const float gamma = static_cast<float>(presence_score[bi]);
    const float temperature = temperatures[bi];
    for (int64_t i = tid; i < length; i += blockDim.x) {
        const int times = repeat_times_now[i];
        float logit_now = static_cast<float>(logits_now[i]);
        if (times != 0) {
            logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
            logit_now = logit_now - times * beta - gamma;
        }
        logits_now[i] = static_cast<T>(logit_now / temperature);
    }
    // Bad-word masking below must overwrite the scaled logits, not race with them.
    __syncthreads();

    // ---- 5. bad_words process ----------------------------------------------
    for (int64_t i = tid; i < bad_words_length; i += blockDim.x) {
        const int64_t bad_words_token_id = bad_words_list[i];
        if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
        logits_now[bad_words_token_id] = -1e10;
    }
}
99+
// Host-side launcher for set_preids_token_penalty_multi_scores_kernel,
// specialised on the Paddle dtype `D` of `logits`.
//
// Reads the batch size and row length from logits.shape()[0] / [1], allocates a
// zero-filled INT32 scratch tensor `repeat_times` with the same shape as
// `logits` (on pre_ids.place()), and launches one thread block per batch row on
// the stream returned by logits.stream().
//
// `logits` and `pre_ids` are modified in place by the kernel even though they
// are taken by const reference (the const is cast away below).
// `temperatures` is read as float; `penalty_scores`, `frequency_score` and
// `presence_score` are read with the same element type as `logits`.
template <paddle::DataType D>
void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids,
                                           const paddle::Tensor& input_ids,
                                           const paddle::Tensor& seq_lens_encoder,
                                           const paddle::Tensor& seq_lens_decoder,
                                           const paddle::Tensor& step_idx,
                                           const paddle::Tensor& stop_flags,
                                           const paddle::Tensor& logits,
                                           const paddle::Tensor& penalty_scores,
                                           const paddle::Tensor& frequency_score,
                                           const paddle::Tensor& presence_score,
                                           const paddle::Tensor& temperatures,
                                           const paddle::Tensor& bad_tokens,
                                           const paddle::Tensor& cur_len,
                                           const paddle::Tensor& min_len,
                                           const paddle::Tensor& eos_token_id) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    // Threads per block used for the launch.
    constexpr int kThreadsPerBlock = 1024;

    const std::vector<int64_t> shape = logits.shape();
    const int64_t bs = shape[0];
    const int64_t length = shape[1];

    // The grid size equals the batch size; a grid of zero blocks is an invalid
    // launch configuration, so there is nothing to do for an empty batch.
    if (bs <= 0) return;

    const int64_t length_id = pre_ids.shape()[1];
    const int64_t length_bad_words = bad_tokens.shape()[0];
    const int64_t length_input_ids = input_ids.shape()[1];
    const int64_t end_length = eos_token_id.shape()[0];

    auto cu_stream = logits.stream();
    // Per-token occurrence counters; the kernel accumulates into them, so they
    // have to start at zero.
    auto repeat_times = paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());

    set_preids_token_penalty_multi_scores_kernel<DataType_><<<bs, kThreadsPerBlock, 0, cu_stream>>>(
        stop_flags.data<bool>(),
        const_cast<int64_t *>(pre_ids.data<int64_t>()),
        input_ids.data<int64_t>(),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        step_idx.data<int64_t>(),
        reinterpret_cast<DataType_ *>(const_cast<data_t *>(penalty_scores.data<data_t>())),
        reinterpret_cast<DataType_ *>(const_cast<data_t *>(frequency_score.data<data_t>())),
        reinterpret_cast<DataType_ *>(const_cast<data_t *>(presence_score.data<data_t>())),
        temperatures.data<float>(),
        cur_len.data<int64_t>(),
        min_len.data<int64_t>(),
        eos_token_id.data<int64_t>(),
        bad_tokens.data<int64_t>(),
        repeat_times.data<int>(),
        reinterpret_cast<DataType_ *>(const_cast<data_t *>(logits.data<data_t>())),
        bs,
        length,
        end_length,
        length_id,
        length_bad_words,
        length_input_ids);
}
156+
// Custom-op entry point: picks the templated launcher that matches the element
// type of `logits` and forwards every tensor to it unchanged.
// Supported logits dtypes are BFLOAT16, FLOAT16 and FLOAT32; any other dtype
// raises via PD_THROW.
void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                                      const paddle::Tensor& input_ids,
                                      const paddle::Tensor& seq_lens_encoder,
                                      const paddle::Tensor& seq_lens_decoder,
                                      const paddle::Tensor& step_idx,
                                      const paddle::Tensor& stop_flags,
                                      const paddle::Tensor& logits,
                                      const paddle::Tensor& penalty_scores,
                                      const paddle::Tensor& frequency_scores,
                                      const paddle::Tensor& presence_scores,
                                      const paddle::Tensor& temperatures,
                                      const paddle::Tensor& bad_tokens,
                                      const paddle::Tensor& cur_len,
                                      const paddle::Tensor& min_len,
                                      const paddle::Tensor& eos_token_id) {
    const auto logits_dtype = logits.type();

    if (logits_dtype == paddle::DataType::BFLOAT16) {
        set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>(
            pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
            stop_flags, logits, penalty_scores, frequency_scores, presence_scores,
            temperatures, bad_tokens, cur_len, min_len, eos_token_id);
    } else if (logits_dtype == paddle::DataType::FLOAT16) {
        set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>(
            pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
            stop_flags, logits, penalty_scores, frequency_scores, presence_scores,
            temperatures, bad_tokens, cur_len, min_len, eos_token_id);
    } else if (logits_dtype == paddle::DataType::FLOAT32) {
        set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>(
            pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
            stop_flags, logits, penalty_scores, frequency_scores, presence_scores,
            temperatures, bad_tokens, cur_len, min_len, eos_token_id);
    } else {
        // Any dtype other than the three handled above is rejected.
        PD_THROW(
            " NOT supported data type. "
            " Only float16, bfloat16 and float32 are supported. ");
    }
}
239+
// Registers the custom op `set_preids_token_penalty_multi_scores`.
// Inputs are listed in the same order as the parameters of
// SetPreidsTokenPenaltyMultiScores. The two outputs are declared as in-place
// aliases of the `logits` and `pre_ids` inputs.
PD_BUILD_OP(set_preids_token_penalty_multi_scores)
    .Inputs({" pre_ids", " input_ids", " seq_lens_encoder", " seq_lens_decoder",
             " step_idx", " stop_flags", " logits",
             " penalty_scores", " frequency_scores", " presence_scores",
             " temperatures", " bad_tokens",
             " cur_len", " min_len", " eos_token_id"})
    .Outputs({" logits_out", " pre_ids_out"})
    .SetInplaceMap({{" logits", " logits_out"},
                    {" pre_ids", " pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(SetPreidsTokenPenaltyMultiScores));
0 commit comments