Skip to content

Commit 2e44dda

Browse files
author
gongel
committed
update attention mask
1 parent 28ea1e2 commit 2e44dda

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

model_zoo/gpt/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def _construct_sample(self, tokens):
448448
loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
449449
position_ids = np.arange(0, seq_length, dtype="int64")
450450

451-
attention_mask = loss_mask
451+
attention_mask = np.ones(seq_length, dtype="int64")
452452
labels = np.array(labels, dtype="int64")
453453
return [tokens, loss_mask, attention_mask, position_ids, labels]
454454

0 commit comments

Comments
 (0)