#include <algorithm>

#include "paddle/extension.h"
2+
// Packs the padded [bsz, sequence_length] int64 id matrix into a dense 1-D
// output with all padding removed.
// Launch layout: one block per batch row (blockIdx.x == row); threads of the
// block stride over that row's seq_lens[row] valid tokens, so any positive
// blockDim.x is correct.
__global__ void RemovePaddingV2(int64_t *output_data,
                                const int64_t *input_data,
                                const int *seq_lens,
                                const int *cum_offsets,
                                const int sequence_length) {
  const int batch_idx = blockIdx.x;
  const int valid_len = seq_lens[batch_idx];
  const int src_base = batch_idx * sequence_length;
  // cum_offsets[batch_idx] is subtracted to turn a padded index into a packed
  // index — presumably the pad-slot count preceding this row; the arithmetic
  // below is the only evidence visible here.
  const int dst_base = src_base - cum_offsets[batch_idx];
  for (int token = threadIdx.x; token < valid_len; token += blockDim.x) {
    output_data[dst_base + token] = input_data[src_base + token];
  }
}
17+
// For each batch row, fills padding_offset (per packed token: how much padding
// precedes it), rewrites cum_offsets_out into an exclusive prefix, and writes
// the cumulative sequence-end positions cu_seqlens_q/cu_seqlens_k.
// Launch layout: one block per batch row; threads stride over the row's valid
// tokens; thread 0 alone writes the per-row scalars.
__global__ void GetPaddingOffsetKernelV2(int *padding_offset,
                                         int *cum_offsets_out,
                                         int *cu_seqlens_q,
                                         int *cu_seqlens_k,
                                         const int *cum_offsets,
                                         const int *seq_lens,
                                         const int max_seq_len) {
  const int batch_idx = blockIdx.x;
  const int lane = threadIdx.x;
  // Exclusive prefix: offset accumulated before this row (0 for the first).
  const int offset_before = (batch_idx == 0) ? 0 : cum_offsets[batch_idx - 1];
  // Start of this row's tokens in the packed (padding-free) layout.
  const int packed_base = batch_idx * max_seq_len - offset_before;
  for (int token = lane; token < seq_lens[batch_idx]; token += blockDim.x) {
    padding_offset[packed_base + token] = offset_before;
  }
  if (lane == 0) {
    cum_offsets_out[batch_idx] = offset_before;
    const int seq_end = (batch_idx + 1) * max_seq_len - cum_offsets[batch_idx];
    cu_seqlens_q[batch_idx + 1] = seq_end;
    cu_seqlens_k[batch_idx + 1] = seq_end;
  }
}
39+
40+
// Host wrapper for get_padding_offset_v2: builds the packed (padding-free)
// view of input_ids plus the offset/metadata tensors consumed by
// variable-length attention kernels.
//
// Inferred from the launches below (confirm against callers):
//   input_ids:   [bsz, seq_length] int64, padded
//   cum_offsets: [bsz] int32 — appears to be a per-row padding prefix sum
//   token_num:   scalar int64, total count of non-pad tokens
//   seq_len:     [bsz] int32, valid tokens per row
// Returns {x_remove_padding, cum_offsets_out, padding_offset,
//          cu_seqlens_q, cu_seqlens_k}.
std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
                                               const paddle::Tensor& cum_offsets,
                                               const paddle::Tensor& token_num,
                                               const paddle::Tensor& seq_len) {
  auto cu_stream = input_ids.stream();
  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int seq_length = input_ids_shape[1];
  // Device-side working copy: GetPaddingOffsetKernelV2 rewrites it into an
  // exclusive prefix, and it is also returned to the caller.
  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
  // D2H copy of the scalar token count, needed to size the outputs.
  // NOTE(review): second arg `false` — presumably non-blocking; reading the
  // host buffer right after relies on an implicit sync. Confirm.
  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

  const int token_num_data = cpu_token_num.data<int64_t>()[0];
  auto x_remove_padding = paddle::full(
      {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::full(
      {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q = paddle::full(
      {bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k = paddle::full(
      {bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  // Round the token count up to a warp multiple, capped at 128 threads.
  // Clamp to at least one warp so an all-padding batch (token_num_data == 0)
  // cannot produce an invalid 0-thread launch configuration.
  const int block_size =
      std::max(std::min((token_num_data + 32 - 1) / 32 * 32, 128), 32);
  GetPaddingOffsetKernelV2<<<bsz, 128, 0, cu_stream>>>(
      padding_offset.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      cum_offsets.data<int>(),
      seq_len.data<int>(),
      seq_length);
  RemovePaddingV2<<<bsz, block_size, 0, cu_stream>>>(
      x_remove_padding.data<int64_t>(),
      input_ids.data<int64_t>(),
      seq_len.data<int>(),
      cum_offsets_out.data<int>(),
      seq_length);
  return {x_remove_padding, cum_offsets_out, padding_offset,
          cu_seqlens_q, cu_seqlens_k};
}
74+
// Static shape inference for get_padding_offset_v2.
// Output order: {x_remove_padding, cum_offsets_out, padding_offset,
// cu_seqlens_q, cu_seqlens_k}. The token-count-dependent outputs
// (x_remove_padding, padding_offset) have a runtime-determined length and are
// reported as -1 (dynamic). Removed the unused `seq_len` local the original
// computed from input_ids_shape[1].
std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(
    const std::vector<int64_t>& input_ids_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& token_num_shape,
    const std::vector<int64_t>& seq_len_shape) {
  const int64_t bsz = seq_len_shape[0];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
83+
// Dtype inference: the packed ids keep input_ids' dtype; the four
// offset/length outputs all keep seq_len's dtype.
std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(
    const paddle::DataType& input_ids_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& token_num_dtype,
    const paddle::DataType& seq_len_dtype) {
  std::vector<paddle::DataType> dtypes;
  dtypes.reserve(5);
  dtypes.push_back(input_ids_dtype);
  for (int i = 0; i < 4; ++i) {
    dtypes.push_back(seq_len_dtype);
  }
  return dtypes;
}
90+
// Registers the custom op with the Paddle framework. Fixed the registered
// names: the original string literals carried a leading space inside the
// quotes (e.g. " input_ids"), which would register/lookup the wrong names.
// NOTE(review): the Inputs label order here is (input_ids, token_num,
// cum_offsets, seq_len) while GetPaddingOffsetV2's parameter order is
// (input_ids, cum_offsets, token_num, seq_len) — labels 2 and 3 look swapped
// relative to the kernel-function parameters. Confirm against Python call
// sites whether the labels or the parameter order is the intended one.
PD_BUILD_OP(get_padding_offset_v2)
    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
    .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset",
              "cu_seqlens_q", "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
    .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetV2InferDtype));