4 files changed: +18 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -259,6 +259,7 @@ def generate(
259
259
```"""
260
260
# set init values
261
261
max_length = max_length if max_length is not None else self .config .max_length
262
+ min_length = min_length if min_length is not None else self .config .min_length
262
263
bos_token_id = bos_token_id if bos_token_id is not None else self .config .bos_token_id
263
264
pad_token_id = pad_token_id if pad_token_id is not None else self .config .pad_token_id
264
265
eos_token_id = eos_token_id if eos_token_id is not None else self .config .eos_token_id
@@ -269,6 +270,11 @@ def generate(
269
270
270
271
if decoder_start_token_id is None and self .config .is_encoder_decoder :
271
272
raise ValueError ("`decoder_start_token_id` has to be defined for encoder-decoder generation." )
273
+ if min_length is not None and min_length > max_length :
274
+ raise ValueError (
275
+ f"Unfeasable length constraints: the minimum length ({ min_length } ) is larger than the maximum "
276
+ f"length ({ max_length } )"
277
+ )
272
278
273
279
if self .config .is_encoder_decoder :
274
280
# add encoder_outputs to model_kwargs
@@ -389,7 +395,6 @@ def _get_logits_processor(
389
395
no_repeat_ngram_size = (
390
396
no_repeat_ngram_size if no_repeat_ngram_size is not None else self .config .no_repeat_ngram_size
391
397
)
392
- min_length = min_length if min_length is not None else self .config .min_length
393
398
eos_token_id = eos_token_id if eos_token_id is not None else self .config .eos_token_id
394
399
forced_bos_token_id = (
395
400
forced_bos_token_id if forced_bos_token_id is not None else self .config .forced_bos_token_id
Original file line number | Diff line number | Diff line change
@@ -1489,6 +1489,11 @@ def _generate(
1489
1489
if pad_token_id is None and eos_token_id is not None :
1490
1490
logger .warning (f"Setting `pad_token_id` to { eos_token_id } (first `eos_token_id`) to generate sequence" )
1491
1491
pad_token_id = eos_token_id
1492
+ if min_length is not None and min_length > max_length :
1493
+ raise ValueError (
1494
+ f"Unfeasable length constraints: the minimum length ({ min_length } ) is larger than the maximum "
1495
+ f"length ({ max_length } )"
1496
+ )
1492
1497
1493
1498
# 2. Define model inputs
1494
1499
input_ids = self ._prepare_model_inputs (input_ids , bos_token_id )
Original file line number | Diff line number | Diff line change
@@ -700,7 +700,6 @@ def _get_logits_processor(
700
700
else self .config .encoder_no_repeat_ngram_size
701
701
)
702
702
bad_words_ids = bad_words_ids if bad_words_ids is not None else self .config .bad_words_ids
703
- min_length = min_length if min_length is not None else self .config .min_length
704
703
eos_token_id = eos_token_id if eos_token_id is not None else self .config .eos_token_id
705
704
diversity_penalty = diversity_penalty if diversity_penalty is not None else self .config .diversity_penalty
706
705
forced_bos_token_id = (
@@ -1185,7 +1184,13 @@ def generate(
1185
1184
)
1186
1185
# default to config if still None
1187
1186
max_length = max_length if max_length is not None else self .config .max_length
1187
+ min_length = min_length if min_length is not None else self .config .min_length
1188
1188
1189
+ if min_length is not None and min_length > max_length :
1190
+ raise ValueError (
1191
+ f"Unfeasable length constraints: the minimum length ({ min_length } ) is larger than the maximum "
1192
+ f"length ({ max_length } )"
1193
+ )
1189
1194
if input_ids_seq_length >= max_length :
1190
1195
input_ids_string = "decoder_input_ids" if self .config .is_encoder_decoder else "input_ids"
1191
1196
logger .warning (
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def _get_logits_processor_and_kwargs(
102
102
diversity_penalty = None ,
103
103
):
104
104
process_kwargs = {
105
- "min_length" : input_length + 1 ,
105
+ "min_length" : input_length + 1 if max_length is None else max_length - 1 ,
106
106
"bad_words_ids" : [[1 , 0 ]],
107
107
"no_repeat_ngram_size" : 2 ,
108
108
"repetition_penalty" : 1.2 ,
You can’t perform that action at this time.
0 commit comments