
Commit e03966e

gante and sgugger authored
TF: XLA stable softmax (#16892)
Co-authored-by: Sylvain Gugger <[email protected]>
1 parent 8246caf commit e03966e

File tree

49 files changed: +210 -142 lines changed


src/transformers/generation_tf_logits_process.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 import numpy as np
 import tensorflow as tf

+from .tf_utils import stable_softmax
 from .utils import add_start_docstrings
 from .utils.logging import get_logger

@@ -166,7 +167,7 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
 topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

 mask_scores = tf.fill(scores.shape, self.filter_value)
-cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
+cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
 score_mask = cumulative_probs < self.top_p

 # Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
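Every file touched by this commit adds the same import, stable_softmax from tf_utils. The helper itself is not part of this diff excerpt; as a rough orientation only, a drop-in replacement of this kind could look like the sketch below (the exact signature and the workaround constant are assumptions, not code from this commit).

import tensorflow as tf


def stable_softmax(logits, axis=None, name=None):
    # Sketch of an assumed implementation: softmax is shift-invariant,
    # softmax(x) == softmax(x + c), so adding a tiny constant leaves the result
    # unchanged while sidestepping the numerical issue that plain tf.nn.softmax
    # can hit when the graph is compiled with XLA.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)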

src/transformers/generation_tf_utils.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@
 TFTopKLogitsWarper,
 TFTopPLogitsWarper,
 )
-from .tf_utils import shape_list
+from .tf_utils import shape_list, stable_softmax
 from .utils import ModelOutput, logging


@@ -3060,7 +3060,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
 logits, sorted_indices, axis=-1, batch_dims=1
 ) # expects logits to be of dim (batch_size, vocab_size)

-cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
+cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)

 # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
 sorted_indices_to_remove = cumulative_probs > top_p
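Both generation hunks change the same step of top-p (nucleus) filtering: sort the logits, turn them into probabilities, accumulate them, and mark everything past the threshold for removal. A standalone toy version of that step with the new helper (the logits and the 0.9 threshold are illustrative, not values from the commit):

import tensorflow as tf

from transformers.tf_utils import stable_softmax  # import added by this commit

# Toy logits over a 4-token vocabulary, already sorted in descending order.
sorted_logits = tf.constant([[2.0, 1.0, 0.0, -1.0]])
top_p = 0.9  # illustrative nucleus threshold

# Previously tf.nn.softmax; stable_softmax is intended as a drop-in replacement.
cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)

# Remove tokens with cumulative probability above the threshold.
sorted_indices_to_remove = cumulative_probs > top_p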

src/transformers/models/albert/modeling_tf_albert.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 MULTIPLE_CHOICE_DUMMY_INPUTS,
 ModelOutput,
@@ -259,7 +259,7 @@ def call(
 attention_scores = tf.add(attention_scores, attention_mask)

 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)

 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
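The remaining model files repeat the same one-line swap inside self-attention: the additive mask is applied to the raw scores, and the softmax that normalizes them into probabilities is replaced. In isolation the pattern looks roughly like this (shapes and mask values are illustrative, not from any of the files below):

import tensorflow as tf

from transformers.tf_utils import stable_softmax  # import added by this commit

# One batch, 4 query positions, 4 key positions; the additive mask blocks the
# last two keys with a large negative value, as the models above do.
attention_scores = tf.random.normal((1, 4, 4))
attention_mask = tf.constant([[[0.0, 0.0, -1e9, -1e9]] * 4])

attention_scores = tf.add(attention_scores, attention_mask)
# Before this commit: tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)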

src/transformers/models/bart/modeling_tf_bart.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 add_code_sample_docstrings,
 add_end_docstrings,
@@ -244,7 +244,7 @@ def call(
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)

 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they

src/transformers/models/bert/modeling_tf_bert.py

Lines changed: 2 additions & 2 deletions
@@ -49,7 +49,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 DUMMY_INPUTS,
 MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -322,7 +322,7 @@ def call(
 attention_scores = tf.add(attention_scores, attention_mask)

 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)

 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.

src/transformers/models/blenderbot/modeling_tf_blenderbot.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 add_code_sample_docstrings,
 add_end_docstrings,
@@ -245,7 +245,7 @@ def call(
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)

 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they

src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 add_code_sample_docstrings,
 add_end_docstrings,
@@ -245,7 +245,7 @@ def call(
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)

 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they

src/transformers/models/clip/modeling_tf_clip.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 ModelOutput,
 add_start_docstrings,
@@ -333,7 +333,7 @@ def call(
 attention_scores = tf.add(attention_scores, attention_mask)

 # Normalize the attention scores to probabilities.
-_attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+_attention_probs = stable_softmax(logits=attention_scores, axis=-1)

 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.

src/transformers/models/convbert/modeling_tf_convbert.py

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 MULTIPLE_CHOICE_DUMMY_INPUTS,
 add_code_sample_docstrings,
@@ -228,7 +228,7 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai

 conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
 conv_kernel_layer = tf.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
-conv_kernel_layer = tf.nn.softmax(conv_kernel_layer, axis=1)
+conv_kernel_layer = stable_softmax(conv_kernel_layer, axis=1)

 paddings = tf.constant(
 [
@@ -270,7 +270,7 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai
 attention_scores = attention_scores + attention_mask

 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+attention_probs = stable_softmax(attention_scores, axis=-1)

 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.

src/transformers/models/ctrl/modeling_tf_ctrl.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_ctrl import CTRLConfig

@@ -79,7 +79,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
 attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
 scaled_attention_logits = scaled_attention_logits + attention_mask

-attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
+attention_weights = stable_softmax(scaled_attention_logits, axis=-1)

 # Mask heads if we want to
 if head_mask is not None:
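As a closing sanity check for a swap of this kind, one can confirm in eager mode that the helper and tf.nn.softmax agree to within numerical tolerance (an illustrative check, not part of the commit):

import tensorflow as tf

from transformers.tf_utils import stable_softmax  # import added by this commit

logits = tf.random.normal((2, 8))
tf.debugging.assert_near(tf.nn.softmax(logits, axis=-1), stable_softmax(logits, axis=-1))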
