Skip to content

Commit fe9529f

Browse files
committed
add docstring
1 parent 5a41d5c commit fe9529f

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tests/test_modeling_tf_common.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,9 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
417417

418418
return new_tf_outputs, new_pt_outputs
419419

420-
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs", attributes=None):
421-
"""
420+
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
421+
"""Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
422+
422423
Args:
423424
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
424425
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
@@ -452,7 +453,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs
452453
# appending each key to the current (string) `names`
453454
attributes = tuple([f"{name}.{k}" for k in tf_keys])
454455
self.check_pt_tf_outputs(
455-
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, name=name, attributes=attributes
456+
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
456457
)
457458

458459
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
@@ -472,7 +473,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs
472473
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
473474

474475
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
475-
self.check_pt_tf_outputs(tf_output, pt_output, model_class, name=attr)
476+
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
476477

477478
elif isinstance(tf_outputs, tf.Tensor):
478479
self.assertTrue(
@@ -500,7 +501,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, name="outputs
500501
tf_outputs[pt_nans] = 0
501502

502503
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
503-
self.assertLessEqual(max_diff, 1e-5, f"{name}: Difference between torch and tf is {max_diff} (>= {1e-5}).")
504+
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
504505
else:
505506
raise ValueError(
506507
f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."

0 commit comments

Comments
 (0)