Commit b4b1bdc

[BugFix] Fix amp usage for evaluation. (#3303)
Wrap the evaluation forward pass in the same paddle.amp.auto_cast context used for training, so that evaluation respects args.use_amp and args.fp16_opt_level.
1 parent: 0711a60


model_zoo/ernie-1.0/run_pretrain.py

Lines changed: 36 additions & 26 deletions
@@ -217,33 +217,43 @@ def run_evaluate(data_loader,
     for eval_step, batch in enumerate(data_loader):
         input_ids, segment_ids, input_mask, masked_lm_positions, \
             masked_lm_labels, next_sentence_labels = batch
+        with paddle.amp.auto_cast(args.use_amp,
+                                  custom_white_list=[
+                                      'softmax',
+                                      'layer_norm',
+                                      'gelu',
+                                  ],
+                                  custom_black_list=[
+                                      "c_softmax_with_cross_entropy",
+                                  ],
+                                  level=args.fp16_opt_level):
 
-        if args.binary_head:
-            prediction_scores, seq_relationship_score = model(
-                input_ids=input_ids,
-                token_type_ids=segment_ids,
-                position_ids=None,
-                attention_mask=input_mask,
-                masked_positions=masked_lm_positions)
-
-            lm_loss, sop_loss = criterion(prediction_scores,
-                                          seq_relationship_score,
-                                          masked_lm_labels,
-                                          next_sentence_labels)
-            loss = lm_loss + sop_loss
-        else:
-            prediction_scores = model(input_ids=input_ids,
-                                      token_type_ids=segment_ids,
-                                      position_ids=None,
-                                      attention_mask=input_mask,
-                                      masked_positions=masked_lm_positions)
-
-            loss = criterion(prediction_scores, None, masked_lm_labels)
-
-        loss_global["loss"] += loss.detach()
-        if args.binary_head:
-            loss_global["lm_loss"] += lm_loss.detach()
-            loss_global["sop_loss"] += sop_loss.detach()
+            if args.binary_head:
+                prediction_scores, seq_relationship_score = model(
+                    input_ids=input_ids,
+                    token_type_ids=segment_ids,
+                    position_ids=None,
+                    attention_mask=input_mask,
+                    masked_positions=masked_lm_positions)
+
+                lm_loss, sop_loss = criterion(prediction_scores,
+                                              seq_relationship_score,
+                                              masked_lm_labels,
+                                              next_sentence_labels)
+                loss = lm_loss + sop_loss
+            else:
+                prediction_scores = model(input_ids=input_ids,
+                                          token_type_ids=segment_ids,
+                                          position_ids=None,
+                                          attention_mask=input_mask,
+                                          masked_positions=masked_lm_positions)
+
+                loss = criterion(prediction_scores, None, masked_lm_labels)
+
+            loss_global["loss"] += loss.detach()
+            if args.binary_head:
+                loss_global["lm_loss"] += lm_loss.detach()
+                loss_global["sop_loss"] += sop_loss.detach()
 
     if eval_step >= iter_steps - 1:
         log_info_dict = dict()
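
The change above only re-indents the existing forward pass and loss computation; the substance is that evaluation now runs inside the same paddle.amp.auto_cast context as training. The sketch below illustrates that pattern in isolation; `model`, `criterion`, and `data_loader` are hypothetical stand-ins, not the ERNIE objects from run_pretrain.py.

import paddle

# Minimal sketch, assuming generic `model`, `criterion`, and `data_loader`
# objects (hypothetical stand-ins for the ERNIE objects in run_pretrain.py).
@paddle.no_grad()
def evaluate(model, criterion, data_loader, use_amp=True, amp_level="O1"):
    model.eval()
    total_loss, steps = 0.0, 0
    for inputs, labels in data_loader:
        # Use the same auto_cast settings as training so that fp16/fp32
        # kernel selection is consistent between the two phases.
        with paddle.amp.auto_cast(use_amp,
                                  custom_white_list=["softmax", "layer_norm", "gelu"],
                                  custom_black_list=["c_softmax_with_cross_entropy"],
                                  level=amp_level):
            loss = criterion(model(inputs), labels)
        total_loss += float(loss)
        steps += 1
    model.train()
    return total_loss / max(steps, 1)

This matters most at level="O2", where the model's parameters may themselves be held in float16: a forward pass outside auto_cast can then fail with dtype mismatches, whereas under "O1" the cost of forgetting the context is typically an fp32 eval pass whose numerics disagree with training.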
