@@ -217,33 +217,43 @@ def run_evaluate(data_loader,
217217 for eval_step , batch in enumerate (data_loader ):
218218 input_ids , segment_ids , input_mask , masked_lm_positions , \
219219 masked_lm_labels , next_sentence_labels = batch
220+ with paddle .amp .auto_cast (args .use_amp ,
221+ custom_white_list = [
222+ 'softmax' ,
223+ 'layer_norm' ,
224+ 'gelu' ,
225+ ],
226+ custom_black_list = [
227+ "c_softmax_with_cross_entropy" ,
228+ ],
229+ level = args .fp16_opt_level ):
220230
221- if args .binary_head :
222- prediction_scores , seq_relationship_score = model (
223- input_ids = input_ids ,
224- token_type_ids = segment_ids ,
225- position_ids = None ,
226- attention_mask = input_mask ,
227- masked_positions = masked_lm_positions )
228-
229- lm_loss , sop_loss = criterion (prediction_scores ,
230- seq_relationship_score ,
231- masked_lm_labels ,
232- next_sentence_labels )
233- loss = lm_loss + sop_loss
234- else :
235- prediction_scores = model (input_ids = input_ids ,
236- token_type_ids = segment_ids ,
237- position_ids = None ,
238- attention_mask = input_mask ,
239- masked_positions = masked_lm_positions )
240-
241- loss = criterion (prediction_scores , None , masked_lm_labels )
242-
243- loss_global ["loss" ] += loss .detach ()
244- if args .binary_head :
245- loss_global ["lm_loss" ] += lm_loss .detach ()
246- loss_global ["sop_loss" ] += sop_loss .detach ()
231+ if args .binary_head :
232+ prediction_scores , seq_relationship_score = model (
233+ input_ids = input_ids ,
234+ token_type_ids = segment_ids ,
235+ position_ids = None ,
236+ attention_mask = input_mask ,
237+ masked_positions = masked_lm_positions )
238+
239+ lm_loss , sop_loss = criterion (prediction_scores ,
240+ seq_relationship_score ,
241+ masked_lm_labels ,
242+ next_sentence_labels )
243+ loss = lm_loss + sop_loss
244+ else :
245+ prediction_scores = model (input_ids = input_ids ,
246+ token_type_ids = segment_ids ,
247+ position_ids = None ,
248+ attention_mask = input_mask ,
249+ masked_positions = masked_lm_positions )
250+
251+ loss = criterion (prediction_scores , None , masked_lm_labels )
252+
253+ loss_global ["loss" ] += loss .detach ()
254+ if args .binary_head :
255+ loss_global ["lm_loss" ] += lm_loss .detach ()
256+ loss_global ["sop_loss" ] += sop_loss .detach ()
247257
248258 if eval_step >= iter_steps - 1 :
249259 log_info_dict = dict ()