Skip to content

Commit 10f5967

Browse files
add input_ids_cpu for proposers to align with the inner implementation
1 parent c2febd9 commit 10f5967

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

llm/predict/predictor.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
2828
from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
2929
from paddle.distributed import fleet
30+
from paddlenlp_ops import speculate_update_input_ids_cpu
3031
from proposers import InferenceWithReferenceProposer
3132

3233
from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
@@ -103,7 +104,7 @@ class PredictorArgument:
103104
default="fp16",
104105
metadata={"help": "avx cachekv type. Supported values: fp16,int8"},
105106
)
106-
batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
107+
batch_size: int = field(default=10, metadata={"help": "The batch size of data."})
107108
benchmark: bool = field(
108109
default=False,
109110
metadata={
@@ -964,6 +965,8 @@ def _preprocess(self, input_text: list[str]):
964965
for k, v in self.model_inputs.items():
965966
v.name = k
966967

968+
return seq_lens
969+
967970

968971
class DygraphBlockInferencePredictor(BlockInferencePredictorMixin):
969972
def __init__(
@@ -1187,7 +1190,7 @@ def __init__(
11871190

11881191
@paddle.no_grad()
11891192
def predict(self, input_texts: list[str], return_tokens=False):
1190-
self._preprocess(input_texts)
1193+
self.seq_lens = self._preprocess(input_texts)
11911194

11921195
# Parameters such as seq_lens_encoder have been set in the preprocessor function,
11931196
# then we use them to init the proposer's args
@@ -1238,10 +1241,6 @@ def predict(self, input_texts: list[str], return_tokens=False):
12381241
return outputs
12391242

12401243
def init_proposer_args(self):
1241-
for bid in range(self.config.batch_size):
1242-
self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
1243-
self.model_inputs["seq_lens_this_time"][bid] - 1
1244-
] # get the last token before padding of this batch
12451244
self.model_inputs["accept_tokens"] = paddle.full(
12461245
shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
12471246
)
@@ -1252,9 +1251,23 @@ def init_proposer_args(self):
12521251
self.model_inputs["actual_draft_token_num"] = paddle.full(
12531252
shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
12541253
)
1255-
if self.config.speculate_method == "inference_with_reference":
1256-
self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
1257-
self.proposer.input_ids_len = self.model_inputs["seq_lens_encoder"].astype("int64").cpu()
1254+
self.model_inputs["input_ids_cpu"] = paddle.full(
1255+
shape=[self.config.batch_size, self.config.total_max_length], fill_value=1, dtype="int64"
1256+
).cpu()
1257+
1258+
for bid in range(self.config.batch_size):
1259+
if self.config.speculate_method == "inference_with_reference":
1260+
speculate_update_input_ids_cpu(
1261+
self.model_inputs["input_ids_cpu"],
1262+
self.model_inputs["input_ids"][bid].cpu().tolist(),
1263+
bid,
1264+
self.config.max_length,
1265+
)
1266+
self.proposer.update(bid, self.seq_lens[bid])
1267+
1268+
self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
1269+
self.model_inputs["seq_lens_this_time"][bid] - 1
1270+
] # get the last token before padding of this batch
12581271

12591272

12601273
class StaticSpeculateInferencePredictor(StaticBlockInferencePredictor):
@@ -1277,11 +1290,6 @@ def __init__(
12771290
self.proposer = None
12781291

12791292
def init_proposer_args(self):
1280-
for bid in range(self.config.batch_size):
1281-
self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
1282-
self.model_inputs["seq_lens_this_time"][bid] - 1
1283-
] # get the last token before padding of this batch
1284-
12851293
self.model_inputs["accept_tokens"] = paddle.full(
12861294
shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
12871295
)
@@ -1292,12 +1300,27 @@ def init_proposer_args(self):
12921300
self.model_inputs["actual_draft_token_num"] = paddle.full(
12931301
shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
12941302
)
1295-
if self.config.speculate_method == "inference_with_reference":
1296-
self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()
1303+
self.model_inputs["input_ids_cpu"] = paddle.full(
1304+
shape=[self.config.batch_size, self.config.total_max_length], fill_value=1, dtype="int64"
1305+
).cpu()
1306+
1307+
for bid in range(self.config.batch_size):
1308+
if self.config.speculate_method == "inference_with_reference":
1309+
speculate_update_input_ids_cpu(
1310+
self.model_inputs["input_ids_cpu"],
1311+
self.model_inputs["input_ids"][bid].cpu().tolist(),
1312+
bid,
1313+
self.config.max_length,
1314+
)
1315+
self.proposer.update(bid, self.seq_lens[bid])
1316+
1317+
self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
1318+
self.model_inputs["seq_lens_this_time"][bid] - 1
1319+
] # get the last token before padding of this batch
12971320

12981321
def predict(self, input_texts: list[str], return_tokens=False):
12991322
s_time = time.time()
1300-
self._preprocess(input_texts)
1323+
self.seq_lens = self._preprocess(input_texts)
13011324

13021325
# Parameters such as seq_lens_encoder have been set in the preprocessor function,
13031326
# then we use them to init the proposer's args

llm/predict/proposers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size
6262
self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu()
6363
self.max_batch_size = max_batch_size
6464
self.max_draft_token_num = max_draft_token_num
65-
self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu()
6665

6766
def update(self, bid: int, seq_len: int):
6867
"""
@@ -79,7 +78,7 @@ def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
7978
seq_lens_encoder = model_inputs["seq_lens_encoder"].cpu()
8079
seq_lens_decoder = model_inputs["seq_lens_decoder"].cpu()
8180
ngram_match(
82-
self.input_ids_cpu,
81+
model_inputs["input_ids_cpu"],
8382
self.input_ids_len.cpu(),
8483
model_inputs["pre_ids"].cpu(),
8584
model_inputs["step_idx"].cpu(),

0 commit comments

Comments (0)