
Commit 6904bb7

ming1753
authored and committed
[bug fixes] fix block attention out of bound risk (PaddlePaddle#69001)
1 parent 284dbf4 commit 6904bb7


1 file changed: +6 −2 lines changed


paddle/phi/kernels/fusion/gpu/block_attn.h

Lines changed: 6 additions & 2 deletions
@@ -4175,7 +4175,8 @@ __global__ void GetDecoderTensorKernel(const T *qkv_out,
                                        const int kv_head_num,
                                        const int seq_len,
                                        const int dim_head,
-                                       const int elem_nums) {
+                                       const int elem_nums,
+                                       const int qkv_out_nums) {
   using LoadT = phi::AlignedVector<T, VecSize>;
   LoadT src_vec;
   const int32_t fused_hidden_size = (q_head_num + 2 * kv_head_num) * dim_head;
@@ -4186,6 +4187,7 @@ __global__ void GetDecoderTensorKernel(const T *qkv_out,
     const int bias_idx = i % fused_hidden_size;
     const int ori_token_idx = bi * seq_len - cum_offsets[bi];
     const int src_offset = ori_token_idx * fused_hidden_size + bias_idx;
+    if (src_offset >= qkv_out_nums) continue;
     phi::Load<T, VecSize>(&qkv_out[src_offset], &src_vec);
     phi::Store<T, VecSize>(src_vec, &qkv_out_decoder[i]);
   }
@@ -4234,6 +4236,7 @@ void GetDecoderTensor(const phi::GPUContext &dev_ctx,
   // kv_num_head + q_num_head, dim_head] rope: [2, bsz, 1, seq_len, dim_head] ->
   // [2, bsz, 1, 1, dim_head]
   int elem_nums = qkv_out_decoder->numel();
+  int qkv_out_nums = qkv_out.numel();
   constexpr int PackSize = VEC_16B / sizeof(T);
   PADDLE_ENFORCE_EQ(
       dim_head % PackSize,
@@ -4255,7 +4258,8 @@ void GetDecoderTensor(const phi::GPUContext &dev_ctx,
       kv_num_head,
       seq_len,
       dim_head,
-      elem_nums);
+      elem_nums,
+      qkv_out_nums);
   if (rope_out_emb) {
     elem_nums = rope_out_emb->numel() / 2;
     pack_num = elem_nums / PackSize;
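
The change above passes the total element count of qkv_out into the kernel and skips any computed source offset that would read past it. Below is a minimal standalone CUDA sketch of that guard pattern in a grid-stride loop; the kernel name GuardedGatherKernel, the row_map indirection, and the buffer sizes are illustrative assumptions and not part of the PaddlePaddle code, which uses vectorized phi::Load / phi::Store on phi::AlignedVector rather than scalar copies.

// Sketch only: a gather kernel that bounds-checks the source offset before
// loading, analogous to the qkv_out_nums guard added in this commit.
#include <cstdio>
#include <cuda_runtime.h>

// Copy rows selected by `row_map` from `src` into `dst`, skipping any index
// that would read past `src_nums` elements of `src`.
__global__ void GuardedGatherKernel(const float *src,
                                    float *dst,
                                    const int *row_map,
                                    int row_size,
                                    int dst_nums,
                                    int src_nums) {
  int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < dst_nums;
       i += stride) {
    int row = i / row_size;
    int col = i % row_size;
    int src_offset = row_map[row] * row_size + col;
    // Bounds guard: a padded or stale row index must not read out of bounds.
    if (src_offset >= src_nums) continue;
    dst[i] = src[src_offset];
  }
}

int main() {
  const int row_size = 8, dst_rows = 4, src_rows = 3;
  const int dst_nums = dst_rows * row_size, src_nums = src_rows * row_size;
  // The last mapped row intentionally points past `src`; the guard skips it.
  int h_row_map[dst_rows] = {0, 2, 1, 5};

  float *d_src, *d_dst;
  int *d_row_map;
  cudaMalloc(&d_src, src_nums * sizeof(float));
  cudaMalloc(&d_dst, dst_nums * sizeof(float));
  cudaMalloc(&d_row_map, dst_rows * sizeof(int));
  cudaMemset(d_src, 0, src_nums * sizeof(float));
  cudaMemset(d_dst, 0, dst_nums * sizeof(float));
  cudaMemcpy(d_row_map, h_row_map, sizeof(h_row_map), cudaMemcpyHostToDevice);

  GuardedGatherKernel<<<1, 128>>>(d_src, d_dst, d_row_map, row_size, dst_nums,
                                  src_nums);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(d_src);
  cudaFree(d_dst);
  cudaFree(d_row_map);
  return 0;
}

Skipping out-of-range offsets with continue, rather than clamping them, keeps padded or stale token slots from reading past the source buffer at the cost of leaving the corresponding destination elements untouched, which matches what the patched kernel does for qkv_out_decoder.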
