
Commit c5d8d5b

[LLM] Support block_attention/cachekv quant for llama (#7649)

* support blha and cache kv quant
* lint
* fix unit test
* fix infer when blha is on
* code refine
* add docs and fix ops
* merge blha read res in predictor
* finish docs
* add docs and unittest
* add unittest
* migrate read res

1 parent 5a32534 commit c5d8d5b

29 files changed: +3380 -79 lines

.github/codecov.yml

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@ coverage:
       threshold: 1% # Allow the coverage to drop by 1%, and posting a success status.
     patch:
       default:
-        target: 80% # lines adjusted Coverage < 80% CI will fail
+        target: 80% # lines adjusted Coverage < 80% CI will fail

csrc/generation/get_output.cc

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

void GetOutput(const paddle::Tensor& x,
               int64_t rank_id,
               bool wait_flag) {
  if (rank_id > 0) return;

  static struct msgdata msg_rcv;

  static key_t key = ftok("./", 1);

  static int msgid = msgget(key, IPC_CREAT | 0666);

  int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
  int ret = -1;
  if (!wait_flag) {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
  } else {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
  }
  if (ret == -1) {
    // read none
    out_data[0] = -2;
    out_data[1] = 0;
    return;
  }

  int bsz = msg_rcv.mtext[1];

  for (int64_t i = 0; i < bsz + 2; i++) {
    out_data[i] = (int64_t)msg_rcv.mtext[i];
  }
  return;
}

PD_BUILD_OP(get_output)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t",
            "wait_flag: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutput));
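For context: GetOutput drains a System V message queue and writes the received token ids into x in place; wait_flag picks between a blocking and a non-blocking msgrcv, and an empty queue is signalled by the -2 sentinel in out_data[0]. Below is a minimal standalone sketch of that receive split, reusing the same key derivation and message layout as the op above (illustrative only, no Paddle dependency; error handling is elided just as in the op).

// Standalone sketch of the blocking vs. non-blocking receive in GetOutput.
#include <cstdio>
#include <sys/ipc.h>
#include <sys/msg.h>

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

int main() {
  key_t key = ftok("./", 1);
  int msgid = msgget(key, IPC_CREAT | 0666);

  msgdata msg;
  // Non-blocking probe: returns -1 immediately (errno == ENOMSG) when the
  // queue is empty -- the wait_flag == false path, which the op reports by
  // writing -2 into out_data[0]. Passing 0 instead of IPC_NOWAIT blocks.
  ssize_t ret = msgrcv(msgid, &msg, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
  if (ret == -1) {
    printf("queue empty, would set out_data[0] = -2\n");
  } else {
    printf("stop_flag=%d bsz=%d\n", msg.mtext[0], msg.mtext[1]);
  }
  return 0;
}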
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
#include "paddle/extension.h"

__global__ void RemovePaddingV2(int64_t *output_data,
                                const int64_t *input_data,
                                const int *seq_lens,
                                const int *cum_offsets,
                                const int sequence_length) {
  const int bi = blockIdx.x;
  const int tid = threadIdx.x;

  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
    const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
    const int src_seq_id = bi * sequence_length + i;
    output_data[tgt_seq_id] = input_data[src_seq_id];
  }
}

__global__ void GetPaddingOffsetKernelV2(int *padding_offset,
                                         int *cum_offsets_out,
                                         int *cu_seqlens_q,
                                         int *cu_seqlens_k,
                                         const int *cum_offsets,
                                         const int *seq_lens,
                                         const int max_seq_len) {
  // get padding offset of each batch
  const int bi = blockIdx.x;
  const int ti = threadIdx.x;
  int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
    padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
  }
  if (ti == 0) {
    cum_offsets_out[bi] = cum_offset;
    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
    cu_seqlens_q[bi + 1] = cum_seq_len;
    cu_seqlens_k[bi + 1] = cum_seq_len;
  }
}

std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
                                               const paddle::Tensor& cum_offsets,
                                               const paddle::Tensor& token_num,
                                               const paddle::Tensor& seq_len) {
  auto cu_stream = input_ids.stream();
  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int seq_length = input_ids_shape[1];
  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

  const int token_num_data = cpu_token_num.data<int64_t>()[0];
  auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
  GetPaddingOffsetKernelV2<<<bsz, 128, 0, cu_stream>>>(
      padding_offset.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      cum_offsets.data<int>(),
      seq_len.data<int>(),
      seq_length);
  RemovePaddingV2<<<bsz, blockSize, 0, cu_stream>>>(
      x_remove_padding.data<int64_t>(),
      input_ids.data<int64_t>(),
      seq_len.data<int>(),
      cum_offsets_out.data<int>(),
      seq_length);
  return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num};
}

std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector<int64_t>& input_ids_shape,
                                                               const std::vector<int64_t>& cum_offsets_shape,
                                                               const std::vector<int64_t>& token_num_shape,
                                                               const std::vector<int64_t>& seq_len_shape) {
  int64_t bsz = seq_len_shape[0];
  int64_t seq_len = input_ids_shape[1];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(const paddle::DataType& input_ids_dtype,
                                                           const paddle::DataType& cum_offsets_dtype,
                                                           const paddle::DataType& token_num_dtype,
                                                           const paddle::DataType& seq_len_dtype) {
  return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}

PD_BUILD_OP(get_padding_offset_v2)
    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
    .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
    .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetV2InferDtype));
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
#include "helper.h"

template <typename T, int VecSize>
__global__ void RebuildPaddingV2Kernel(T *output_data,
                                       const T *input_data,
                                       const int *cum_offsets,
                                       const int *seq_len_decoder,
                                       const int *seq_len_encoder,
                                       const int seq_len,
                                       const int dim_embed,
                                       const int elem_nums) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;
  const int global_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
    const int bi = i / dim_embed;
    const int bias_idx = i % dim_embed;
    int seq_id = 0;
    // just encoder or stop, get last token; just decoder, get first token.
    if (seq_len_decoder[bi] == 0) {
      if (seq_len_encoder[bi] != 0) {
        seq_id = seq_len_encoder[bi] - 1;
      } else {
        return;
      }
    }
    const int ori_token_idx = bi * seq_len - cum_offsets[bi] + seq_id;
    const int src_offset = ori_token_idx * dim_embed + bias_idx;
    Load<T, VecSize>(&input_data[src_offset], &src_vec);
    Store<T, VecSize>(src_vec, &output_data[i]);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> rebuild_padding_v2(const paddle::Tensor& tmp_out, // [token_num, dim_embed]
                                               const paddle::Tensor& cum_offsets, // [bsz, 1]
                                               const paddle::Tensor& seq_lens_decoder,
                                               const paddle::Tensor& seq_lens_encoder,
                                               int max_input_length) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = tmp_out.stream();
  std::vector<int64_t> tmp_out_shape = tmp_out.shape();
  const int token_num = tmp_out_shape[0];
  const int dim_embed = tmp_out_shape[1];
  const int bsz = cum_offsets.shape()[0];
  auto out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
  constexpr int PackSize = VEC_16B / sizeof(DataType_);
  int elem_nums = out.numel();
  int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  const int grid_size = (pack_num + blocksize - 1) / blocksize;
  RebuildPaddingV2Kernel<DataType_, PackSize><<<grid_size, blocksize, 0, tmp_out.stream()>>>(
      reinterpret_cast<DataType_*>(out.data<data_t>()),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(tmp_out.data<data_t>())),
      cum_offsets.data<int>(),
      seq_lens_decoder.data<int>(),
      seq_lens_encoder.data<int>(),
      max_input_length,
      dim_embed,
      elem_nums);
  return {out};
}

std::vector<paddle::Tensor> RebuildPaddingV2(const paddle::Tensor& tmp_out, // [token_num, dim_embed]
                                             const paddle::Tensor& cum_offsets, // [bsz, 1]
                                             const paddle::Tensor& seq_lens_decoder,
                                             const paddle::Tensor& seq_lens_encoder,
                                             int max_input_length) {
  switch (tmp_out.type()) {
    case paddle::DataType::BFLOAT16: {
      return rebuild_padding_v2<paddle::DataType::BFLOAT16>(
          tmp_out, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length);
    }
    case paddle::DataType::FLOAT16: {
      return rebuild_padding_v2<paddle::DataType::FLOAT16>(
          tmp_out, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length);
    }
    case paddle::DataType::FLOAT32: {
      return rebuild_padding_v2<paddle::DataType::FLOAT32>(
          tmp_out, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> RebuildPaddingV2InferShape(const std::vector<int64_t>& tmp_out_shape,
                                                             const std::vector<int64_t>& cum_offsets_shape,
                                                             const std::vector<int64_t>& seq_lens_decoder_shape,
                                                             const std::vector<int64_t>& seq_lens_encoder_shape) {
  int64_t bsz = cum_offsets_shape[0];
  int64_t dim_embed = tmp_out_shape[1];
  return {{bsz, dim_embed}};
}

std::vector<paddle::DataType> RebuildPaddingV2InferDtype(const paddle::DataType& tmp_out_dtype,
                                                         const paddle::DataType& cum_offsets_dtype,
                                                         const paddle::DataType& seq_lens_decoder_dtype,
                                                         const paddle::DataType& seq_lens_encoder_dtype) {
  return {tmp_out_dtype};
}

PD_BUILD_OP(rebuild_padding_v2)
    .Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder"})
    .Outputs({"out"})
    .Attrs({"max_input_length: int"})
    .SetKernelFn(PD_KERNEL(RebuildPaddingV2))
    .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingV2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingV2InferDtype));
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#include "paddle/extension.h"

void SetStopValue(const paddle::Tensor& not_need_stop) {
  bool *stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  stop_data[0] = true;
}

PD_BUILD_OP(reset_stop_value)
    .Inputs({"not_need_stop"})
    .Outputs({"not_need_stop_out"})
    .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}})
    .SetKernelFn(PD_KERNEL(SetStopValue));
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

void SaveOutMmsg(const paddle::Tensor& x,
                 const paddle::Tensor& not_need_stop,
                 int64_t rank_id) {
  if (rank_id > 0) return;
  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
  int64_t *x_data = x_cpu.data<int64_t>();
  static struct msgdata msg_sed;
  static key_t key = ftok("./", 1);
  static int msgid = msgget(key, IPC_CREAT | 0666);

  msg_sed.mtype = 1;
  bool not_need_stop_data = not_need_stop.data<bool>()[0];
  msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
  int bsz = x.shape()[0];
  msg_sed.mtext[1] = bsz;
  for (int i = 2; i < bsz + 2; i++) {
    msg_sed.mtext[i] = (int)x_data[i - 2];
  }
  if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
    // printf("full msg buffer\n");
  }
  return;
}

PD_BUILD_OP(save_output)
    .Inputs({"x", "not_need_stop"})
    .Attrs({"rank_id: int64_t"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
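save_output is the producer half of the System V queue that get_output (earlier in this commit) consumes: it posts each step's sampled token ids per batch row, prefixed by a stop flag and the batch size. A minimal round-trip sketch of that handshake, under the assumption that both sides run from the same working directory so ftok("./", 1) derives the same key (illustrative only, no Paddle dependency):

// Round-trip sketch of the save_output -> get_output message layout.
#include <cstdio>
#include <sys/ipc.h>
#include <sys/msg.h>

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

int main() {
  key_t key = ftok("./", 1);
  int msgid = msgget(key, IPC_CREAT | 0666);

  // Sender side (what SaveOutMmsg does): pack stop flag, batch size,
  // then one sampled token id per batch row.
  msgdata snd;
  snd.mtype = 1;
  snd.mtext[0] = 1;   // 1 = keep generating, -1 = stop
  snd.mtext[1] = 2;   // bsz
  snd.mtext[2] = 42;  // token for row 0
  snd.mtext[3] = 7;   // token for row 1
  msgsnd(msgid, &snd, (MAX_BSZ + 2) * 4, 0);

  // Receiver side (what GetOutput does): blocking read, then unpack.
  msgdata rcv;
  msgrcv(msgid, &rcv, (MAX_BSZ + 2) * 4, 0, 0);
  printf("stop_flag=%d bsz=%d tokens=%d,%d\n",
         rcv.mtext[0], rcv.mtext[1], rcv.mtext[2], rcv.mtext[3]);

  msgctl(msgid, IPC_RMID, nullptr);  // remove the queue when done
  return 0;
}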
