1+ // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+ #include " helper.h"
16+
17+ template <typename T>
18+ __global__ void set_value_by_id (const int *seq_lens,
19+ const bool *stop_flags,
20+ const float *alibi_slopes,
21+ const int64_t *tgt_pos,
22+ T *output_data,
23+ int *sequence_lengths,
24+ int bs,
25+ int length,
26+ int num_head) {
27+ int bs_id = blockIdx .x ;
28+ int hid = threadIdx .x ;
29+ if (bs_id < bs) {
30+ T *output_data_now = output_data + bs_id * num_head * length + hid * length;
31+ float tgt_pos_now = static_cast <float >(tgt_pos[bs_id]);
32+ output_data_now[seq_lens[bs_id]] = static_cast <T>(tgt_pos_now * alibi_slopes[hid]);
33+ if (stop_flags[bs_id]) {
34+ sequence_lengths[bs_id] = 0 ;
35+ }
36+ }
37+ }
38+
39+ template <paddle::DataType D>
40+ std::vector<paddle::Tensor> set_mask_value (const paddle::Tensor& input_data,
41+ const paddle::Tensor& stop_flags,
42+ const paddle::Tensor& seq_lens,
43+ const paddle::Tensor& alibi_slopes,
44+ const paddle::Tensor& tgt_pos
45+ ) {
46+ typedef PDTraits<D> traits_;
47+ typedef typename traits_::DataType DataType_;
48+ typedef typename traits_::data_t data_t ;
49+
50+ PD_CHECK (seq_lens.dtype () == paddle::DataType::INT32);
51+ PD_CHECK (stop_flags.dtype () == paddle::DataType::BOOL);
52+ auto cu_stream = input_data.stream ();
53+ std::vector<int64_t > input_data_shape = input_data.shape ();
54+ std::vector<int64_t > seq_lens_shape = seq_lens.shape ();
55+ auto sequence_lengths = seq_lens.copy_to (seq_lens.place (), false );
56+
57+ int input_bs = input_data_shape[0 ];
58+ int length = input_data_shape[3 ];
59+ int seq_bs = seq_lens_shape[0 ];
60+ int num_head = alibi_slopes.shape ()[0 ];
61+
62+ int grid_size = input_bs;
63+ int block_size = num_head;
64+ set_value_by_id<<<grid_size, block_size, 0 , cu_stream>>> (seq_lens.data <int >(),
65+ stop_flags.data <bool >(),
66+ alibi_slopes.data <float >(),
67+ tgt_pos.data <int64_t >(),
68+ reinterpret_cast <DataType_*>(const_cast <data_t *>(input_data.data <data_t >())),
69+ sequence_lengths.data <int >(), seq_bs, length, num_head);
70+ return {sequence_lengths};
71+ }
72+
73+ std::vector<paddle::Tensor> SetMaskValue (const paddle::Tensor& input_data,
74+ const paddle::Tensor& stop_flags,
75+ const paddle::Tensor& seq_lens,
76+ const paddle::Tensor& alibi_slopes,
77+ const paddle::Tensor& tgt_pos) {
78+ switch (input_data.type ()) {
79+ case paddle::DataType::BFLOAT16: {
80+ return set_mask_value<paddle::DataType::BFLOAT16>(
81+ input_data,
82+ stop_flags,
83+ seq_lens,
84+ alibi_slopes,
85+ tgt_pos
86+ );
87+ }
88+ case paddle::DataType::FLOAT16: {
89+ return set_mask_value<paddle::DataType::FLOAT16>(
90+ input_data,
91+ stop_flags,
92+ seq_lens,
93+ alibi_slopes,
94+ tgt_pos
95+ );
96+ }
97+ case paddle::DataType::FLOAT32: {
98+ return set_mask_value<paddle::DataType::FLOAT32>(
99+ input_data,
100+ stop_flags,
101+ seq_lens,
102+ alibi_slopes,
103+ tgt_pos
104+ );
105+ }
106+ default : {
107+ PD_THROW (
108+ " NOT supported data type. "
109+ " Only float16, bfloat16 and float32 are supported. " );
110+ break ;
111+ }
112+ }
113+ }
114+
115+ std::vector<std::vector<int64_t >> SetMaskValueInferShape (const std::vector<int64_t >& input_data_shape,
116+ const std::vector<int64_t >& stop_flags_shape,
117+ const std::vector<int64_t >& seq_lens_shape,
118+ const std::vector<int64_t >& alibi_slopes_shape,
119+ const std::vector<int64_t >& tgt_pos) {
120+ return {seq_lens_shape};
121+ }
122+
123+ std::vector<paddle::DataType> SetMaskValueInferDtype (const paddle::DataType& input_data_dtype,
124+ const paddle::DataType& stop_flags_dtype,
125+ const paddle::DataType& seq_lens_dtype,
126+ const paddle::DataType& alibi_slopes_dtype,
127+ const paddle::DataType& tgt_pos_dtype) {
128+ return {seq_lens_dtype};
129+ }
130+
131+ PD_BUILD_OP (set_alibi_mask_value)
132+ .Inputs({" input_data" , " stop_flags" , " seq_lens" , " alibi_slopes" , " tgt_pos" })
133+ .Outputs({" sequence_lengths" })
134+ .SetKernelFn(PD_KERNEL(SetMaskValue))
135+ .SetInferShapeFn(PD_INFER_SHAPE(SetMaskValueInferShape))
136+ .SetInferDtypeFn(PD_INFER_DTYPE(SetMaskValueInferDtype));
0 commit comments