
Commit c2febd9

update
1 parent b4fedab commit c2febd9

4 files changed: +17 -23 lines changed


llm/speculate_decoding/__init__.py renamed to llm/predict/env.py

Lines changed: 7 additions & 1 deletion

@@ -10,4 +10,10 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
+# limitations under the License.
+
+# Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
+MAX_BSZ = 512
+# Note(@Wanglongzhi2001): SPECULATE_MAX_BSZ must be the same as definition in speculate_get_output / speculate_save_output
+SPECULATE_MAX_BSZ = 256
+MAX_DRAFT_TOKENS = 6
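
Why the constants live in their own module: the C++ custom ops (get_output / save_output and their speculate_ variants) size their result buffers from the same values, so the Python side must mirror them exactly. Below is a minimal numpy sketch of the kind of buffer these bounds imply; the layout is illustrative only, not the actual op code.

    import numpy as np

    MAX_BSZ = 512            # must match get_output / save_output
    SPECULATE_MAX_BSZ = 256  # must match speculate_get_output / speculate_save_output
    MAX_DRAFT_TOKENS = 6

    # Hypothetical layout: one row per request slot, one column per draft
    # token plus the verified token; -1 marks unused slots.
    output_buffer = np.full((SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS + 1), -1, dtype=np.int64)
    assert output_buffer.shape == (256, 7)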

llm/predict/predictor.py

Lines changed: 9 additions & 21 deletions

@@ -24,10 +24,11 @@
 import numpy as np
 import paddle
 import paddle.incubate.multiprocessing as mp
+from env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api
 from paddle.distributed import fleet
+from proposers import InferenceWithReferenceProposer

-from llm.speculate_decoding.proposers import InferenceWithReferenceProposer
 from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.taskflow.utils import static_mode_guard
@@ -47,12 +48,6 @@
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger

-# Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
-MAX_BSZ = 512
-# Note(@Wanglongzhi2001): SPECULATE_MAX_BSZ must be the same as definition in speculate_get_output / speculate_save_output
-SPECULATE_MAX_BSZ = 256
-MAX_DRAFT_TOKENS = 6
-

 @dataclass
 class PredictorArgument:
@@ -108,7 +103,7 @@ class PredictorArgument:
         default="fp16",
         metadata={"help": "avx cachekv type. Supported values: fp16,int8"},
     )
-    batch_size: int = field(default=10, metadata={"help": "The batch size of data."})
+    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
     benchmark: bool = field(
         default=False,
         metadata={
@@ -1242,15 +1237,11 @@ def predict(self, input_texts: list[str], return_tokens=False):
         else:
             return outputs

-    def _preprocess(self, input_text: list[str]):
-        super()._preprocess(input_text)
-
+    def init_proposer_args(self):
         for bid in range(self.config.batch_size):
             self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
                 self.model_inputs["seq_lens_this_time"][bid] - 1
             ] # get the last token before padding of this batch
-
-    def init_proposer_args(self):
         self.model_inputs["accept_tokens"] = paddle.full(
             shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
@@ -1286,6 +1277,11 @@ def __init__(
         self.proposer = None

     def init_proposer_args(self):
+        for bid in range(self.config.batch_size):
+            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
+                self.model_inputs["seq_lens_this_time"][bid] - 1
+            ] # get the last token before padding of this batch
+
         self.model_inputs["accept_tokens"] = paddle.full(
             shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
         )
@@ -1299,14 +1295,6 @@ def init_proposer_args(self):
         if self.config.speculate_method == "inference_with_reference":
             self.proposer.input_ids_cpu = self.model_inputs["input_ids"].cpu()

-    def _preprocess(self, input_text: list[str]):
-        super()._preprocess(input_text)
-
-        for bid in range(self.config.batch_size):
-            self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][
-                self.model_inputs["seq_lens_this_time"][bid] - 1
-            ] # get the last token before padding of this batch
-
     def predict(self, input_texts: list[str], return_tokens=False):
         s_time = time.time()
         self._preprocess(input_texts)
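
The loop that seeds pre_ids (moved from _preprocess into init_proposer_args above) copies each row's last real prompt token into column 0, presumably so the proposer has a starting token to extend. A self-contained numpy sketch of that indexing, with made-up inputs:

    import numpy as np

    input_ids = np.array([[11, 12, 13,  0,  0,  0],
                          [21, 22, 23, 24, 25,  0]])  # 0 = padding
    seq_lens_this_time = np.array([3, 5])             # real length of each row
    pre_ids = np.full_like(input_ids, -1)

    for bid in range(input_ids.shape[0]):
        # seq_len - 1 indexes the last token before padding of this batch
        pre_ids[bid, 0] = input_ids[bid][seq_lens_this_time[bid] - 1]

    print(pre_ids[:, 0])  # [13 25]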
llm/speculate_decoding/proposers.py renamed to llm/predict/proposers.py

File renamed without changes.

paddlenlp/utils/llm_utils.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
 from sklearn.metrics import accuracy_score

-from llm.predict.predictor import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
+from llm.predict.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 from paddlenlp.datasets import ZeroPaddingIterableDataset
 from paddlenlp.generation import GenerationConfig
 from paddlenlp.trainer import Trainer, TrainerCallback
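
With this change, utilities read the shared constants from the lightweight env module instead of importing the whole predictor (which pulls in paddle, the proposers, and the rest of the inference stack), likely to sidestep a circular import between predictor.py and llm_utils.py. Only the module path changes for callers:

    from llm.predict.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ

    assert (SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS) == (256, 6)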
