
Commit b4fedab

update
1 parent 42a90d1 commit b4fedab

File tree: 7 files changed, +254 -250 lines changed


csrc/gpu/speculate_decoding_kernels/speculate_save_output.cc

Lines changed: 3 additions & 10 deletions
@@ -19,7 +19,7 @@
 #include <sys/types.h>
 #include "paddle/extension.h"
 
-#define MAX_BSZ 512
+#define MAX_BSZ 256
 #define MAX_DRAFT_TOKENS 6
 
 struct msgdata {
@@ -31,7 +31,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
                                 const paddle::Tensor& accept_num,
                                 const paddle::Tensor& not_need_stop,
                                 int64_t rank_id,
-                                const int msg_queue_id) {
+                                const int msg_queue_id) {
   if (rank_id > 0) return;
 
   int max_draft_tokens = accept_tokens.shape()[1];
@@ -71,7 +71,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
     }
   }
   if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4, 0)) == -1) {
-    printf("full msg buffer\n");
+    printf("full msg buffer\n");
   }
   return;
 }
@@ -98,10 +98,3 @@ PD_BUILD_OP(speculate_save_output)
     .Outputs({"x_out"})
     .SetInplaceMap({{"accept_tokens", "x_out"}})
     .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
-
-PD_BUILD_OP(speculate_save_output_dynamic)
-    .Inputs({"accept_tokens", "accept_num", "not_need_stop"})
-    .Attrs({"rank_id: int64_t", "msg_queue_id: int"})
-    .Outputs({"x_out"})
-    .SetInplaceMap({{"accept_tokens", "x_out"}})
-    .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
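
For reference, halving MAX_BSZ also halves the System V message payload that msgsnd sends above. A standalone arithmetic sketch (assuming the literal 4 in the expression is the size of a 4-byte int, as the kernel's int payload suggests):

# Payload size implied by the constants in speculate_save_output.cc after this commit.
MAX_BSZ = 256          # was 512 before this commit
MAX_DRAFT_TOKENS = 6
payload_bytes = (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4
print(payload_bytes)   # 7176 bytes, down from 14344 bytes with MAX_BSZ = 512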

csrc/gpu/speculate_decoding_kernels/speculate_step.cu

Lines changed: 157 additions & 190 deletions
Large diffs are not rendered by default.

llm/predict/predictor.py

Lines changed: 19 additions & 17 deletions
@@ -27,7 +27,7 @@
 from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
 from paddle.distributed import fleet
 
-from llm.speculate_decoding.proposer import InferenceWithReferenceProposer
+from llm.speculate_decoding.proposers import InferenceWithReferenceProposer
 from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.taskflow.utils import static_mode_guard
@@ -49,6 +49,8 @@
 
 # Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
 MAX_BSZ = 512
+# Note(@Wanglongzhi2001): SPECULATE_MAX_BSZ must be the same as definition in speculate_get_output / speculate_save_output
+SPECULATE_MAX_BSZ = 256
 MAX_DRAFT_TOKENS = 6
 
 
@@ -106,7 +108,7 @@ class PredictorArgument:
         default="fp16",
         metadata={"help": "avx cachekv type. Supported values: fp16,int8"},
     )
-    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
+    batch_size: int = field(default=10, metadata={"help": "The batch size of data."})
     benchmark: bool = field(
         default=False,
         metadata={
@@ -142,7 +144,7 @@ class PredictorArgument:
             "help": "speculate method, it should be one of ['None', 'autoregressive', 'inference_with_reference']"
         },
     )
-    speculate_max_draft_tokens: int = field(
+    speculate_max_draft_token_num: int = field(
         default=1,
         metadata={"help": "the max length of draft tokens for speculate method."},
     )
@@ -1180,7 +1182,7 @@ def __init__(
         # init speculate components
         if config.speculate_method == "inference_with_reference":
             self.proposer = InferenceWithReferenceProposer(
-                config.speculate_max_draft_tokens,
+                config.speculate_max_draft_token_num,
                 config.speculate_max_ngram_size,
                 config.batch_size,
                 config.max_length,
@@ -1192,7 +1194,7 @@ def __init__(
     def predict(self, input_texts: list[str], return_tokens=False):
         self._preprocess(input_texts)
 
-        # Parameters such as seq_lens_encoder have been set in the preprocessor function,
+        # Parameters such as seq_lens_encoder have been set in the preprocessor function,
         # then we use them to init the proposer's args
         self.init_proposer_args()
 
@@ -1206,7 +1208,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
             read_res_process.start()
 
         output_tensor = paddle.full(
-            shape=[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
+            shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
         ).cpu()
         tensor_queue.put(output_tensor)
         if self.tensor_parallel_rank == 0:
@@ -1250,14 +1252,14 @@ def _preprocess(self, input_text: list[str]):
 
     def init_proposer_args(self):
         self.model_inputs["accept_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["accept_num"] = paddle.full(shape=[self.config.batch_size], fill_value=0, dtype="int32")
         self.model_inputs["draft_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["actual_draft_token_num"] = paddle.full(
-            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_tokens, dtype="int32"
+            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
         )
         if self.config.speculate_method == "inference_with_reference":
             self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
@@ -1275,7 +1277,7 @@ def __init__(
         # init speculate components
         if config.speculate_method == "inference_with_reference":
             self.proposer = InferenceWithReferenceProposer(
-                config.speculate_max_draft_tokens,
+                config.speculate_max_draft_token_num,
                 config.speculate_max_ngram_size,
                 config.batch_size,
                 config.max_length,
@@ -1285,14 +1287,14 @@ def __init__(
 
     def init_proposer_args(self):
         self.model_inputs["accept_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["accept_num"] = paddle.full(shape=[self.config.batch_size], fill_value=0, dtype="int32")
         self.model_inputs["draft_tokens"] = paddle.full(
-            shape=[self.config.batch_size, self.config.speculate_max_draft_tokens + 1], fill_value=0, dtype="int64"
+            shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
         self.model_inputs["actual_draft_token_num"] = paddle.full(
-            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_tokens, dtype="int32"
+            shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
         )
         if self.config.speculate_method == "inference_with_reference":
             self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
@@ -1309,7 +1311,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         s_time = time.time()
         self._preprocess(input_texts)
 
-        # Parameters such as seq_lens_encoder have been set in the preprocessor function,
+        # Parameters such as seq_lens_encoder have been set in the preprocessor function,
         # then we use them to init the proposer's args
         self.init_proposer_args()
         logger.info(f"preprocess spend {time.time() - s_time}")
@@ -1332,7 +1334,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         if self.tensor_parallel_rank == 0:
             read_res_process.start()
         output_tensor = paddle.full(
-            shape=[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
+            shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
         ).cpu()
         tensor_queue.put(output_tensor)
         if self.tensor_parallel_rank == 0:
@@ -1505,7 +1507,7 @@ def create_predictor(
     elif predictor_args.speculate_method is not None:
         config.max_seq_len = predictor_args.total_max_length
         config.block_size = predictor_args.block_size
-        config.speculate_max_draft_tokens = predictor_args.speculate_max_draft_tokens
+        config.speculate_max_draft_token_num = predictor_args.speculate_max_draft_token_num
         config.speculate_max_ngram_size = predictor_args.speculate_max_ngram_size
         config.speculate_verify_window = predictor_args.speculate_verify_window
         config.speculate_max_candidate_len = predictor_args.speculate_max_candidate_len
@@ -1738,7 +1740,7 @@ def create_predictor(
     elif predictor_args.speculate_method is not None:
         config.max_seq_len = predictor_args.total_max_length
         config.block_size = predictor_args.block_size
-        config.speculate_max_draft_tokens = predictor_args.speculate_max_draft_tokens
+        config.speculate_max_draft_token_num = predictor_args.speculate_max_draft_token_num
         config.speculate_max_ngram_size = predictor_args.speculate_max_ngram_size
         config.speculate_verify_window = predictor_args.speculate_verify_window
         config.speculate_max_candidate_len = predictor_args.speculate_max_candidate_len
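
The net effect in predictor.py is a straight rename of the config field plus a dedicated buffer size for speculative output. A minimal sketch of how the renamed field flows into the proposer, mirroring the hunks above (here config stands in for the parsed PredictorArgument and the surrounding setup is elided):

from llm.speculate_decoding.proposers import InferenceWithReferenceProposer

# config is assumed to carry the renamed field from PredictorArgument.
proposer = InferenceWithReferenceProposer(
    config.speculate_max_draft_token_num,  # renamed from speculate_max_draft_tokens
    config.speculate_max_ngram_size,
    config.batch_size,
    config.max_length,
)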

llm/speculate_decoding/proposer.py renamed to llm/speculate_decoding/proposers.py

Lines changed: 11 additions & 5 deletions
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from __future__ import annotations
 
 import paddle
 from paddlenlp_ops import ngram_match
@@ -43,10 +43,10 @@ class InferenceWithReferenceProposer(Proposer):
     It match tokens in the input and output as draft tokens.
     """
 
-    def __init__(self, max_draft_tokens: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs):
+    def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs):
         """
         Args:
-            max_draft_tokens (int):
+            max_draft_token_num (int):
                 Maximum number of tokens a proposer can generate at one time.
                 The hyperparameter of k in the paper.
             max_ngram_size (int):
@@ -61,9 +61,15 @@ def __init__(self, max_draft_tokens: int, max_ngram_size: int, max_batch_size: i
         self.max_ngram_size = max_ngram_size
         self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu()
         self.max_batch_size = max_batch_size
-        self.max_draft_tokens = max_draft_tokens
+        self.max_draft_token_num = max_draft_token_num
         self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu()
 
+    def update(self, bid: int, seq_len: int):
+        """
+        Used when inserting a new query to update the length of the input_ids.
+        """
+        self.input_ids_len[bid] = seq_len
+
     def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
         """
         Use ngram_match to get draft tokens from the input and output.
@@ -84,7 +90,7 @@ def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
             seq_lens_decoder,
             kargs["real_batch_size"],
             self.max_ngram_size,
-            self.max_draft_tokens,
+            self.max_draft_token_num,
         )
 
         model_inputs["draft_tokens"][:] = draft_tokens.cuda()

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 30 additions & 10 deletions
@@ -649,6 +649,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
         self.linear = fused_linear
 
+        # used in speculative decoding, if speculate_max_draft_token_num is 1
+        # and speculate_method is None, it will be autogressive decoding.
+        self.speculate_max_draft_token_num = 1
+        self.speculate_method = None
+
     def init_weight(self):
         self.qkv_weights = []
         self.linear_weights = []
@@ -1095,7 +1100,6 @@ def forward(
         kwargs["decoder_block_shape_q"] = 16
         kwargs["max_partition_size"] = 32768
         kwargs["encoder_max_partition_size"] = 32768
-        kwargs["speculate_max_draft_token_num"] = 5
 
         from paddlenlp_ops import get_block_shape_and_split_kv_block
 
@@ -1120,7 +1124,7 @@ def forward(
             kwargs.get("decoder_block_shape_q", 16),
             self.num_heads // self.kv_num_heads,
             kwargs.get("block_size", 64),
-            kwargs["speculate_max_draft_token_num"],
+            self.speculate_max_draft_token_num,
         )
 
         residual_input = src
@@ -2259,9 +2263,9 @@ def compute_attn(
             kwargs.get("decoder_block_shape_q", 16),
             kwargs.get("max_partition_size", 32768),
             kwargs.get("encoder_max_partition_size", 32768),
-            kwargs["speculate_max_draft_token_num"],  # speculate_max_draft_token_num
+            self.speculate_max_draft_token_num,  # speculate_max_draft_token_num
             True,  # causal
-            False,  # speculate_decoder
+            self.speculate_method is not None,  # speculate_decoder
         )[0]
         else:
             if core.is_compiled_with_xpu():
@@ -2441,9 +2445,9 @@ def compute_attn(
             kwargs.get("decoder_block_shape_q", 16),
             kwargs.get("max_partition_size", 32768),
             kwargs.get("encoder_max_partition_size", 32768),
-            kwargs["speculate_max_draft_token_num"],  # speculate_max_draft_token_num
+            self.speculate_max_draft_token_num,  # speculate_max_draft_token_num
             True,  # causal
-            False,  # speculate_decoder
+            self.speculate_method is not None,  # speculate_decoder
         )[0]
         else:
             fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
@@ -3258,7 +3262,17 @@ def forward(
         return out, caches
 
 
-class FusedSpeculateMultiTransformer(FusedAppendMultiTransformer):
+class FusedSpeculateMultiTransformer(FusedBlockMultiTransformer):
+    def __init__(
+        self,
+        speculate_max_draft_token_num: int,
+        speculate_method: str = None,
+        config: FusedMultiTransformerConfig = None,
+    ):
+        super().__init__(config)
+        self.speculate_max_draft_token_num = speculate_max_draft_token_num
+        self.speculate_method = speculate_method
+
     def post_process(self, **kwargs):
         embed_dim = self.config.embed_dim
         multi_block_output = kwargs.get("multi_block_output", None)
@@ -3279,12 +3293,18 @@ def post_process(self, **kwargs):
         return out
 
 
-class FusedSpeculateMultiTransformerA8W8(FusedAppendMultiTransformerA8W8):
-    def __init__(self, config: FusedMultiTransformerConfig):
+class FusedSpeculateMultiTransformerA8W8(FusedBlockMultiTransformerA8W8):
+    def __init__(
+        self,
+        speculate_max_draft_token_num: int,
+        speculate_method: str = None,
+        config: FusedMultiTransformerConfig = None,
+    ):
         super().__init__(config)
+        self.speculate_max_draft_token_num = speculate_max_draft_token_num
+        self.speculate_method = speculate_method
 
     def post_process(self, **kwargs):
-        logger.info("use FusedSpeculateMultiTransformerA8W8")
         embed_dim = self.config.embed_dim
         multi_block_output = kwargs.get("multi_block_output", None)
         cum_offsets = kwargs.get("cum_offsets", None)
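
With the draft-token count and speculate method now stored on the layer itself instead of being threaded through kwargs, constructing the speculate transformers looks roughly like the sketch below. The keyword names come from the new __init__ signatures in the diff; how the FusedMultiTransformerConfig is built is elided, and whether these layers are constructed directly like this or inside a model class is an assumption:

# config is assumed to be a fully populated FusedMultiTransformerConfig.
transformer = FusedSpeculateMultiTransformer(
    speculate_max_draft_token_num=5,
    speculate_method="inference_with_reference",
    config=config,
)

# The A8W8 variant takes the same arguments after this commit.
quant_transformer = FusedSpeculateMultiTransformerA8W8(
    speculate_max_draft_token_num=5,
    speculate_method="inference_with_reference",
    config=config,
)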
