
Commit b703e6c

improve comment

1 parent fe9529f commit b703e6c

File tree

1 file changed: +3 additions, −3 deletions

tests/test_modeling_tf_common.py

Lines changed: 3 additions & 3 deletions
@@ -363,9 +363,9 @@ def _make_attention_mask_non_null(self, inputs_dict):
         if k in inputs_dict:
             attention_mask = inputs_dict[k]

-            # # make sure no all 0s attention masks - to avoid failure at this moment.
-            # # TODO: remove this line once the TODO below is implemented.
-            # attention_mask = tf.ones_like(attention_mask, dtype=tf.int32)
+            # Make sure no all 0s attention masks - to avoid failure at this moment.
+            # Put `1` at the beginning of sequences to make it still work when combining causal attention masks.
+            # TODO: remove this line once a fix regarding large negative values for attention mask is done.
             attention_mask = tf.concat(
                 [tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], axis=-1
             )
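The trick in the diff is to overwrite only the first position of each attention mask with `1`, so no mask is all zeros while the rest of the mask is left untouched. A minimal standalone sketch of that operation, using NumPy in place of TensorFlow (`np.concatenate` mirrors the `tf.concat` call; the function name here is hypothetical, not part of the test suite):

```python
import numpy as np

def make_attention_mask_non_null(attention_mask: np.ndarray) -> np.ndarray:
    """Force the first token of every sequence to be attended (mask value 1),
    so no row of the attention mask is all zeros."""
    # Ones with the same dtype/shape as the first column, as in the diff's
    # tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype).
    first_column = np.ones_like(attention_mask[:, :1])
    # Re-attach the untouched remainder of the mask along the last axis.
    return np.concatenate([first_column, attention_mask[:, 1:]], axis=-1)

mask = np.array([[0, 0, 0], [1, 1, 0]], dtype=np.int32)
print(make_attention_mask_non_null(mask))
# [[1 0 0]
#  [1 1 0]]
```

Note that only the first row actually changes; a mask that already starts with `1` passes through unchanged, which is why this is safe to apply unconditionally in the test helper.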
