Commit b53ef28

gante authored and amyeroberts committed
TF: Fix generation repetition penalty with XLA (huggingface#18648)
1 parent 771d6c0 commit b53ef28

1 file changed, +3 -1 lines changed

src/transformers/generation_tf_logits_process.py

Lines changed: 3 additions & 1 deletion
@@ -262,9 +262,11 @@ def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf

         # Scatters the penalties
         token_penalties = tf.ones(logits.shape)
+        batch_size = input_ids.shape[0]
+        seq_len = tf.shape(input_ids)[1]  # the sequence length has dynamic size, hence the dynamic shape
         indexable_prev_input_ids = tf.concat(
             (
-                tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
+                tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
                 tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
             ),
             axis=1,
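
To make the changed expression easier to follow, here is a minimal plain-Python sketch of the index pairs that `indexable_prev_input_ids` appears to hold: one `(batch row, token id)` pair per previously generated token. It is an illustrative analogue using lists rather than the TensorFlow code from the commit, and the helper name `build_scatter_indices` is hypothetical. In the actual change, `seq_len` is read with `tf.shape(input_ids)[1]` because, as the diff comment notes, the sequence length has dynamic size.

```python
from typing import List, Tuple


def build_scatter_indices(input_ids: List[List[int]]) -> List[Tuple[int, int]]:
    """Pair every previous token id with the index of the batch row it came from."""
    batch_size = len(input_ids)
    seq_len = len(input_ids[0]) if batch_size else 0
    # Analogue of tf.repeat(tf.range(batch_size), seq_len):
    # each row index is repeated seq_len times.
    row_indices = [row for row in range(batch_size) for _ in range(seq_len)]
    # Analogue of tf.reshape(input_ids, [-1]): flatten row by row.
    flat_token_ids = [token for row in input_ids for token in row]
    # Analogue of concatenating the two expanded columns along axis=1.
    return list(zip(row_indices, flat_token_ids))


example = [[5, 7, 5], [2, 9, 4]]
indices = build_scatter_indices(example)
print(indices)
```

Each pair addresses one cell of a `(batch_size, vocab_size)` grid, which is the kind of index list a scatter update expects. Note that a token repeated within a row, such as `5` above, yields a repeated pair.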
