Skip to content

Commit b0bf301

Browse files
authored
Generate: min length can't be larger than max length (#16668)
* min length must be smaller than max length
* Update min_length in tests
1 parent 4868a83 commit b0bf301

File tree

4 files changed

+18
-3
lines changed

4 files changed

+18
-3
lines changed

src/transformers/generation_flax_utils.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -259,6 +259,7 @@ def generate(
259259
```"""
260260
# set init values
261261
max_length = max_length if max_length is not None else self.config.max_length
262+
min_length = min_length if min_length is not None else self.config.min_length
262263
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
263264
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
264265
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -269,6 +270,11 @@ def generate(
269270

270271
if decoder_start_token_id is None and self.config.is_encoder_decoder:
271272
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
273+
if min_length is not None and min_length > max_length:
274+
raise ValueError(
275+
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
276+
f"length ({max_length})"
277+
)
272278

273279
if self.config.is_encoder_decoder:
274280
# add encoder_outputs to model_kwargs
@@ -389,7 +395,6 @@ def _get_logits_processor(
389395
no_repeat_ngram_size = (
390396
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
391397
)
392-
min_length = min_length if min_length is not None else self.config.min_length
393398
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
394399
forced_bos_token_id = (
395400
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id

src/transformers/generation_tf_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1489,6 +1489,11 @@ def _generate(
14891489
if pad_token_id is None and eos_token_id is not None:
14901490
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
14911491
pad_token_id = eos_token_id
1492+
if min_length is not None and min_length > max_length:
1493+
raise ValueError(
1494+
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
1495+
f"length ({max_length})"
1496+
)
14921497

14931498
# 2. Define model inputs
14941499
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)

src/transformers/generation_utils.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -700,7 +700,6 @@ def _get_logits_processor(
700700
else self.config.encoder_no_repeat_ngram_size
701701
)
702702
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
703-
min_length = min_length if min_length is not None else self.config.min_length
704703
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
705704
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
706705
forced_bos_token_id = (
@@ -1185,7 +1184,13 @@ def generate(
11851184
)
11861185
# default to config if still None
11871186
max_length = max_length if max_length is not None else self.config.max_length
1187+
min_length = min_length if min_length is not None else self.config.min_length
11881188

1189+
if min_length is not None and min_length > max_length:
1190+
raise ValueError(
1191+
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
1192+
f"length ({max_length})"
1193+
)
11891194
if input_ids_seq_length >= max_length:
11901195
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
11911196
logger.warning(

tests/generation/test_generation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def _get_logits_processor_and_kwargs(
102102
diversity_penalty=None,
103103
):
104104
process_kwargs = {
105-
"min_length": input_length + 1,
105+
"min_length": input_length + 1 if max_length is None else max_length - 1,
106106
"bad_words_ids": [[1, 0]],
107107
"no_repeat_ngram_size": 2,
108108
"repetition_penalty": 1.2,

0 commit comments

Comments
 (0)