Skip to content

Commit 87f5510

Browse files
committed
fix ci
1 parent ca2aedb commit 87f5510

File tree

2 files changed

+10
-31
lines changed

2 files changed

+10
-31
lines changed

llm/predict/predictor.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,15 +1200,17 @@ def create_predictor(
12001200

12011201
max_position_embeddings = get_model_max_position_embeddings(config)
12021202
if max_position_embeddings is None:
1203-
max_position_embeddings = 2048
1204-
logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048")
1205-
1206-
if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
1207-
raise ValueError(
1208-
f"The sum of src_length<{predictor_args.src_length}> and "
1209-
f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
1210-
f"the maximum position embedding size<{max_position_embeddings}>"
1203+
max_position_embeddings = predictor_args.src_length + predictor_args.max_length
1204+
logger.warning(
1205+
f"Can not retrieval `max_position_embeddings` from config.json, use default value {max_position_embeddings}"
12111206
)
1207+
else:
1208+
if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
1209+
raise ValueError(
1210+
f"The sum of src_length<{predictor_args.src_length}> and "
1211+
f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
1212+
f"the maximum position embedding size<{max_position_embeddings}>"
1213+
)
12121214

12131215
# update config parameter for inference predictor
12141216
if predictor_args.decode_strategy == "greedy_search":

tests/llm/test_predictor.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import unittest
1818

1919
import paddle
20-
import pytest
2120
from parameterized import parameterized_class
2221

2322
from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
@@ -199,28 +198,6 @@ def test_predictor(self):
199198
self.assertGreaterEqual(count / len(result_0), 0.8)
200199

201200

202-
class PredictorBaseTest(LLMTest, unittest.TestCase):
203-
def load_test_config(self):
204-
config = load_test_config("./tests/fixtures/llm/predictor.yaml", "inference-predict")
205-
config["model_name_or_path"] = "__internal_testing__/micro-random-llama"
206-
207-
return config
208-
209-
def test_create_predictor_with_unexpected_length(self):
210-
from predict.predictor import predict
211-
212-
config = self.load_test_config()
213-
config.pop("src_length", None)
214-
config.pop("max_length", None)
215-
216-
with pytest.raises(ValueError, match="The sum of src_length<1025> and"):
217-
config["max_length"] = 1024
218-
config["src_length"] = 1025
219-
220-
with argv_context_guard(config):
221-
predict()
222-
223-
224201
@parameterized_class(
225202
["model_name_or_path", "model_class"],
226203
[

0 commit comments

Comments
 (0)