@@ -417,8 +417,9 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
417
417
418
418
return new_tf_outputs , new_pt_outputs
419
419
420
- def check_pt_tf_outputs (self , tf_outputs , pt_outputs , model_class , name = "outputs" , attributes = None ):
421
- """
420
+ def check_pt_tf_outputs (self , tf_outputs , pt_outputs , model_class , tol = 1e-5 , name = "outputs" , attributes = None ):
421
+ """Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
422
+
422
423
Args:
423
424
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
424
425
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
@@ -452,7 +453,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs
452
453
# appending each key to the current (string) `names`
453
454
attributes = tuple ([f"{ name } .{ k } " for k in tf_keys ])
454
455
self .check_pt_tf_outputs (
455
- tf_outputs .to_tuple (), pt_outputs .to_tuple (), model_class , name = name , attributes = attributes
456
+ tf_outputs .to_tuple (), pt_outputs .to_tuple (), model_class , tol = tol , name = name , attributes = attributes
456
457
)
457
458
458
459
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
@@ -472,7 +473,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs
472
473
attributes = tuple ([f"{ name } _{ idx } " for idx in range (len (tf_outputs ))])
473
474
474
475
for tf_output , pt_output , attr in zip (tf_outputs , pt_outputs , attributes ):
475
- self .check_pt_tf_outputs (tf_output , pt_output , model_class , name = attr )
476
+ self .check_pt_tf_outputs (tf_output , pt_output , model_class , tol = tol , name = attr )
476
477
477
478
elif isinstance (tf_outputs , tf .Tensor ):
478
479
self .assertTrue (
@@ -500,7 +501,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs
500
501
tf_outputs [pt_nans ] = 0
501
502
502
503
max_diff = np .amax (np .abs (tf_outputs - pt_outputs ))
503
- self .assertLessEqual (max_diff , 1e-5 , f"{ name } : Difference between torch and tf is { max_diff } (>= { 1e-5 } )." )
504
+ self .assertLessEqual (max_diff , tol , f"{ name } : Difference between torch and tf is { max_diff } (>= { tol } )." )
504
505
else :
505
506
raise ValueError (
506
507
f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got { type (tf_outputs )} instead."
0 commit comments