Skip to content

Commit 2bd5a46

Browse files
authored
Merge branch 'PaddlePaddle:develop' into jaywan/fix_sync_memcpy
2 parents 5dddf0c + 4a7665d commit 2bd5a46

File tree

27 files changed

+959
-194
lines changed

27 files changed

+959
-194
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
17+
// Writes the ALiBi bias value tgt_pos * slope[head] into each (batch, head)
// row of the attention-mask tensor at the position given by seq_lens, and
// zeroes sequence_lengths for batches whose stop flag is set.
// Launch shape: one block per batch, one thread per attention head.
template <typename T>
__global__ void set_value_by_id(const int *seq_lens,
                                const bool *stop_flags,
                                const float *alibi_slopes,
                                const int64_t *tgt_pos,
                                T *output_data,
                                int *sequence_lengths,
                                int bs,
                                int length,
                                int num_head) {
  const int batch_idx = blockIdx.x;
  const int head_idx = threadIdx.x;
  if (batch_idx >= bs) {
    return;  // guard: grid may be launched with more blocks than batches
  }
  // Row for this (batch, head) pair in the [bs, num_head, length] layout.
  T *row = output_data + (batch_idx * num_head + head_idx) * length;
  const float pos = static_cast<float>(tgt_pos[batch_idx]);
  row[seq_lens[batch_idx]] = static_cast<T>(pos * alibi_slopes[head_idx]);
  if (stop_flags[batch_idx]) {
    sequence_lengths[batch_idx] = 0;
  }
}
38+
39+
template <paddle::DataType D>
40+
std::vector<paddle::Tensor> set_mask_value(const paddle::Tensor& input_data,
41+
const paddle::Tensor& stop_flags,
42+
const paddle::Tensor& seq_lens,
43+
const paddle::Tensor& alibi_slopes,
44+
const paddle::Tensor& tgt_pos
45+
) {
46+
typedef PDTraits<D> traits_;
47+
typedef typename traits_::DataType DataType_;
48+
typedef typename traits_::data_t data_t;
49+
50+
PD_CHECK(seq_lens.dtype() == paddle::DataType::INT32);
51+
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
52+
auto cu_stream = input_data.stream();
53+
std::vector<int64_t> input_data_shape = input_data.shape();
54+
std::vector<int64_t> seq_lens_shape = seq_lens.shape();
55+
auto sequence_lengths = seq_lens.copy_to(seq_lens.place(), false);
56+
57+
int input_bs = input_data_shape[0];
58+
int length = input_data_shape[3];
59+
int seq_bs = seq_lens_shape[0];
60+
int num_head = alibi_slopes.shape()[0];
61+
62+
int grid_size = input_bs;
63+
int block_size = num_head;
64+
set_value_by_id<<<grid_size, block_size, 0, cu_stream>>>(seq_lens.data<int>(),
65+
stop_flags.data<bool>(),
66+
alibi_slopes.data<float>(),
67+
tgt_pos.data<int64_t>(),
68+
reinterpret_cast<DataType_*>(const_cast<data_t*>(input_data.data<data_t>())),
69+
sequence_lengths.data<int>(), seq_bs, length, num_head);
70+
return {sequence_lengths};
71+
}
72+
73+
// Entry point registered with PD_BUILD_OP: dispatches on the dtype of
// input_data to the typed implementation. Supports float32, float16 and
// bfloat16; any other dtype raises via PD_THROW.
std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data,
                                         const paddle::Tensor& stop_flags,
                                         const paddle::Tensor& seq_lens,
                                         const paddle::Tensor& alibi_slopes,
                                         const paddle::Tensor& tgt_pos) {
    const auto dtype = input_data.type();
    if (dtype == paddle::DataType::BFLOAT16) {
        return set_mask_value<paddle::DataType::BFLOAT16>(
            input_data, stop_flags, seq_lens, alibi_slopes, tgt_pos);
    }
    if (dtype == paddle::DataType::FLOAT16) {
        return set_mask_value<paddle::DataType::FLOAT16>(
            input_data, stop_flags, seq_lens, alibi_slopes, tgt_pos);
    }
    if (dtype == paddle::DataType::FLOAT32) {
        return set_mask_value<paddle::DataType::FLOAT32>(
            input_data, stop_flags, seq_lens, alibi_slopes, tgt_pos);
    }
    PD_THROW(
        "NOT supported data type. "
        "Only float16, bfloat16 and float32 are supported. ");
}
114+
115+
// Infer-shape function for set_alibi_mask_value: the single output
// (sequence_lengths) is a copy of seq_lens, so it has seq_lens' shape.
// Fix: the last parameter was named `tgt_pos` although it receives a shape;
// renamed to `tgt_pos_shape` for consistency with the other parameters
// (parameter names are not part of the ABI, so callers are unaffected).
std::vector<std::vector<int64_t>> SetMaskValueInferShape(const std::vector<int64_t>& input_data_shape,
                                                         const std::vector<int64_t>& stop_flags_shape,
                                                         const std::vector<int64_t>& seq_lens_shape,
                                                         const std::vector<int64_t>& alibi_slopes_shape,
                                                         const std::vector<int64_t>& tgt_pos_shape) {
    return {seq_lens_shape};
}
122+
123+
// Infer-dtype function for set_alibi_mask_value: the single output
// (sequence_lengths) is a copy of seq_lens, so it carries seq_lens' dtype
// (expected to be int32, enforced at kernel-launch time by PD_CHECK).
std::vector<paddle::DataType> SetMaskValueInferDtype(const paddle::DataType& input_data_dtype,
                                                     const paddle::DataType& stop_flags_dtype,
                                                     const paddle::DataType& seq_lens_dtype,
                                                     const paddle::DataType& alibi_slopes_dtype,
                                                     const paddle::DataType& tgt_pos_dtype) {
    return {seq_lens_dtype};
}
130+
131+
// Register the custom op with Paddle.
// Inputs mirror SetMaskValue's parameters; the single output
// "sequence_lengths" is the (possibly zeroed) copy of seq_lens.
// Note: input_data is modified in place by the kernel even though it is
// declared only as an input here.
PD_BUILD_OP(set_alibi_mask_value)
    .Inputs({"input_data", "stop_flags", "seq_lens", "alibi_slopes", "tgt_pos"})
    .Outputs({"sequence_lengths"})
    .SetKernelFn(PD_KERNEL(SetMaskValue))
    .SetInferShapeFn(PD_INFER_SHAPE(SetMaskValueInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SetMaskValueInferDtype));

csrc/setup_cuda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"./generation/write_cache_kv.cu",
3232
"./generation/encode_rotary_qk.cu",
3333
"./generation/top_p_sampling.cu",
34+
"./generation/set_alibi_mask_value.cu",
3435
]
3536
),
3637
)

llm/benchmark.sh

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
export PYTHONPATH=$(dirname $(pwd)):$PYTHONPATH
216

317
export FLAGS_control_flow_use_new_executor=1
@@ -6,10 +20,10 @@ export FLAGS_allocator_strategy=naive_best_fit
620
export FLAGS_fraction_of_gpu_memory_to_use=0.92
721

822
python predictor.py \
9-
--model_name_or_path ./llama-13b-inference_model_fp16 \
23+
--model_name_or_path ./llama7b-inference_model_fp16 \
1024
--dtype float16 \
1125
--src_length 300 \
12-
--max_length 400 \
26+
--max_length 100 \
1327
--output_file "infer.json" \
1428
--mode "static" \
1529
--batch_size 1 \

llm/llama/README.md

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,30 @@
44

55
**支持模型权重:**
66

7-
| Model |
8-
| ------------------------------ |
9-
| facebook/llama-7b |
10-
| facebook/llama-13b |
11-
| facebook/llama-30b |
12-
| facebook/llama-65b |
13-
| meta-llama/Llama-2-7b |
14-
| meta-llama/Llama-2-7b-chat |
15-
| meta-llama/Llama-2-13b |
16-
| meta-llama/Llama-2-13b-chat |
17-
| meta-llama/Llama-2-70b |
18-
| meta-llama/Llama-2-70b-chat |
19-
| ziqingyang/chinese-llama-7b |
20-
| ziqingyang/chinese-llama-13b |
21-
| ziqingyang/chinese-alpaca-7b |
22-
| ziqingyang/chinese-alpaca-13b |
23-
| idea-ccnl/ziya-llama-13b-v1 |
24-
| linly-ai/chinese-llama-2-7b |
25-
| baichuan-inc/Baichuan-7B |
26-
| baichuan-inc/Baichuan-13B-Base |
27-
| baichuan-inc/Baichuan-13B-Chat |
7+
| Model |
8+
| ---------------------------------|
9+
| facebook/llama-7b |
10+
| facebook/llama-13b |
11+
| facebook/llama-30b |
12+
| facebook/llama-65b |
13+
| meta-llama/Llama-2-7b |
14+
| meta-llama/Llama-2-7b-chat |
15+
| meta-llama/Llama-2-13b |
16+
| meta-llama/Llama-2-13b-chat |
17+
| meta-llama/Llama-2-70b |
18+
| meta-llama/Llama-2-70b-chat |
19+
| ziqingyang/chinese-llama-7b |
20+
| ziqingyang/chinese-llama-13b |
21+
| ziqingyang/chinese-alpaca-7b |
22+
| ziqingyang/chinese-alpaca-13b |
23+
| idea-ccnl/ziya-llama-13b-v1 |
24+
| linly-ai/chinese-llama-2-7b |
25+
| baichuan-inc/Baichuan-7B |
26+
| baichuan-inc/Baichuan-13B-Base |
27+
| baichuan-inc/Baichuan-13B-Chat |
28+
| FlagAlpha/Llama2-Chinese-7b-Chat |
29+
| FlagAlpha/Llama2-Chinese-13b-Chat |
30+
2831

2932

3033
使用方法:

0 commit comments

Comments
 (0)