2 files changed: +10 −31 lines changed

@@ -1200,15 +1200,17 @@ def create_predictor(
 
     max_position_embeddings = get_model_max_position_embeddings(config)
     if max_position_embeddings is None:
-        max_position_embeddings = 2048
-        logger.warning("Cannot retrieve `max_position_embeddings` from config.json, using default value 2048")
-
-    if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
-        raise ValueError(
-            f"The sum of src_length<{predictor_args.src_length}> and "
-            f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
-            f"the maximum position embedding size<{max_position_embeddings}>"
+        max_position_embeddings = predictor_args.src_length + predictor_args.max_length
+        logger.warning(
+            f"Cannot retrieve `max_position_embeddings` from config.json, using default value {max_position_embeddings}"
         )
+    else:
+        if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
+            raise ValueError(
+                f"The sum of src_length<{predictor_args.src_length}> and "
+                f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
+                f"the maximum position embedding size<{max_position_embeddings}>"
+            )
 
     # update config parameter for inference predictor
     if predictor_args.decode_strategy == "greedy_search":
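In plain terms, the hunk changes the fallback when config.json carries no `max_position_embeddings`: instead of assuming a hard-coded 2048, the limit defaults to exactly `src_length + max_length`, and the sum is only validated against the limit when the config actually declares one. Below is a minimal standalone sketch of that logic; the `resolve_max_position_embeddings` helper and the `SimpleNamespace` stand-in for `predictor_args` are illustrative assumptions, not part of PaddleNLP's API.

from types import SimpleNamespace


def resolve_max_position_embeddings(config_value, predictor_args):
    """Mirror the patched behavior in create_predictor (sketch).

    If the config declares no limit, fall back to src_length + max_length
    rather than a fixed 2048; only a declared limit is enforced.
    """
    if config_value is None:
        # Fallback: size the window to exactly what this request needs.
        return predictor_args.src_length + predictor_args.max_length
    if predictor_args.src_length + predictor_args.max_length > config_value:
        raise ValueError(
            f"The sum of src_length<{predictor_args.src_length}> and "
            f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
            f"the maximum position embedding size<{config_value}>"
        )
    return config_value


args = SimpleNamespace(src_length=1025, max_length=1024)
print(resolve_max_position_embeddings(None, args))  # 2049: no error, unlike the old 2048 default
print(resolve_max_position_embeddings(4096, args))  # 4096: the sum fits the declared limit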
@@ -17,7 +17,6 @@
 import unittest
 
 import paddle
-import pytest
 from parameterized import parameterized_class
 
 from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
@@ -199,28 +198,6 @@ def test_predictor(self):
         self.assertGreaterEqual(count / len(result_0), 0.8)
 
 
-class PredictorBaseTest(LLMTest, unittest.TestCase):
-    def load_test_config(self):
-        config = load_test_config("./tests/fixtures/llm/predictor.yaml", "inference-predict")
-        config["model_name_or_path"] = "__internal_testing__/micro-random-llama"
-
-        return config
-
-    def test_create_predictor_with_unexpected_length(self):
-        from predict.predictor import predict
-
-        config = self.load_test_config()
-        config.pop("src_length", None)
-        config.pop("max_length", None)
-
-        with pytest.raises(ValueError, match="The sum of src_length<1025> and"):
-            config["max_length"] = 1024
-            config["src_length"] = 1025
-
-            with argv_context_guard(config):
-                predict()
-
-
 @parameterized_class(
     ["model_name_or_path", "model_class"],
     [
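Note: the removed PredictorBaseTest depended on the old hard-coded default. With src_length=1025 and max_length=1024 the sum (2049) exceeded 2048, so create_predictor raised the matched ValueError. Under the new fallback, a missing config limit defaults to exactly src_length + max_length, so that input no longer raises, which is why the test and its now-unused pytest import are dropped.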