
Conversation

Collaborator

@ydshieh ydshieh commented Apr 1, 2022

What does this PR do?

Improve PT/TF equivalence test.

To make the review a bit easier for you, I made some comments. Here is a summary of the changes:

  • test_pt_tf_model_equivalence in TensorFlow LED and CLIP are removed: the common one can handle it.
  • test_pt_tf_model_equivalence in TensorFlow LXMERT and ViTMAE are removed: we only need to overwrite
    • prepare_pt_inputs_from_tf_inputs for LXMERT
    • check_pt_tf_models for ViTMAE
  • Main changes in TFModelTesterMixin.test_pt_tf_model_equivalence
    • restructure the code into components so they can be overwritten separately, instead of overwriting the whole big block
    • move some ugly (temporary) logic blocks outside:
      • _make_attention_mask_non_null
      • _postprocessing_to_ignore_test_cases
    • About check_pt_tf_outputs:
      • it can now handle instances of ModelOutput (needed for the CLIP models)
      • better failure message: it prints the name of the tensor where the large PT/TF difference occurs, like output.hidden_states or output.text_model_output.attentions_1
    • A better way to handle the cases where PT/TF outputs have different keys: we now compare the output values only for the keys present in both outputs (a rough sketch of this structure follows below).
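For readers who want a mental model of the restructuring, here is a simplified sketch of the recursive output check described above. It is not the exact code added to test_modeling_tf_common.py; the method name and the dotted-name scheme follow this PR's description, but the body is illustrative only.

```python
import numpy as np


# Sketch of a recursive PT/TF output comparison, meant to live in TFModelTesterMixin.
# It builds a dotted name such as "output.text_model_output.attentions_1" while recursing,
# so a failing assertion points at the exact tensor.
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, tol=1e-5, name="output"):
    if isinstance(tf_outputs, (tuple, list)):
        # Sequences (e.g. hidden_states, attentions): recurse with an index suffix.
        for idx, (tf_item, pt_item) in enumerate(zip(tf_outputs, pt_outputs)):
            self.check_pt_tf_outputs(tf_item, pt_item, tol=tol, name=f"{name}_{idx}")
    elif hasattr(tf_outputs, "to_tuple"):
        # ModelOutput subclasses (e.g. TFCLIPOutput): compare only the keys present in both outputs.
        common_keys = [k for k in tf_outputs.keys() if k in pt_outputs.keys()]
        for key in common_keys:
            self.check_pt_tf_outputs(tf_outputs[key], pt_outputs[key], tol=tol, name=f"{name}.{key}")
    else:
        # Leaf case: compare the tensors as numpy arrays and name the offending tensor on failure.
        tf_array = tf_outputs.numpy()
        pt_array = pt_outputs.detach().numpy()
        max_diff = np.amax(np.abs(tf_array - pt_array))
        self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
```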

Once this PR is approved/merged:

  • To work on the same PT/TF equivalence test on PT side (should be very quick)
  • To apply the same logic to PT/Flax equivalence test, both on Flax and PT sides.


HuggingFaceDocBuilderDev commented Apr 1, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -497,130 +496,6 @@ def test_keras_save_load(self):
after_outputs = model(inputs_dict)
self.assert_outputs_same(after_outputs, outputs)

# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
Collaborator Author

This is no longer needed - the test in TF common can handle nested outputs, including instances of ModelOutput.

# TODO: Remove this once a more thorough pt/tf equivalence could be implemented in `test_modeling_tf_common.py`.
# (Currently, such a test will fail some other model tests: it requires some time to fix them.)
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence_extra(self):
Collaborator Author

This was added before to give TF-LED a strong test while the common version was still a loose one.

Now that the common test is (very) strong, we no longer need this extra test for TF-LED.

if not is_torch_available():
return

def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
import torch
Collaborator Author

Can I add import torch here without is_torch_available or require_torch? This method will be called only inside test_pt_tf_model_equivalence, which is already decorated with is_pt_tf_cross_test.

Collaborator

That's just a marker that reads an env variable, so I think it should have the require_torch just in case, but I'm not sure if we are very consistent with that. @LysandreJik might know better.

Member

I don't think it really matters as it is indeed already decorated with is_pt_tf_cross_test. We don't have a convention set, so feel free to choose the simplest approach.
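For context, a small illustration of the pattern under discussion (the class name and helper body here are hypothetical, not taken from the PR): the test method carries the is_pt_tf_cross_test marker, so a helper called only from it can import torch locally without its own require_torch guard.

```python
import unittest

from transformers.testing_utils import is_pt_tf_cross_test


class SomeTFModelTest(unittest.TestCase):
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        import torch  # only reached from the decorated cross test below, so torch is installed

        return {k: torch.from_numpy(v.numpy()) for k, v in tf_inputs_dict.items()}

    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        ...  # build tf_inputs_dict, create the PT/TF models, then compare their outputs
```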

Comment on lines +497 to +500
if isinstance(value, dict):
pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
elif isinstance(value, (list, tuple)):
pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
Collaborator Author

This is the LXMERT-specific part of the test.

(It would be possible to move this part into the common PT/TF test method, but I think it's fine/better to overwrite it here.)

Comment on lines -528 to -532
def torch_type(key):
if key in ("visual_feats", "visual_pos"):
return torch.float32
else:
return torch.long
Collaborator Author

Removed. The new version uses

elif tf_inputs_dict[key].dtype.is_floating:

I find it cleaner and more general.
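To make the comparison concrete, here is a simplified sketch of the dtype-based conversion idea; the actual common helper is more complete (it also handles nested inputs), so treat this as illustration only.

```python
import torch


# Sketch: instead of hard-coding which keys are float (visual_feats, visual_pos),
# look at the dtype of each TF tensor when building the PT inputs.
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
    pt_inputs_dict = {}
    for key, value in tf_inputs_dict.items():
        if tf_inputs_dict[key].dtype.is_floating:
            # Floating-point inputs (e.g. visual features, pixel values) stay float32.
            pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
        else:
            # Integer-like inputs (input ids, attention masks, ...) become long tensors.
            pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.long)
    return pt_inputs_dict
```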

Comment on lines -534 to -546
def recursive_numpy_convert(iterable):
return_dict = {}
for key, value in iterable.items():
if isinstance(value, dict):
return_dict[key] = recursive_numpy_convert(value)
else:
if isinstance(value, (list, tuple)):
return_dict[key] = (
torch.from_numpy(iter_value.numpy()).to(torch_type(key)) for iter_value in value
)
else:
return_dict[key] = torch.from_numpy(value.numpy()).to(torch_type(key))
return return_dict
Collaborator Author

@ydshieh ydshieh Apr 7, 2022

In the new version, this is handled in prepare_pt_inputs_from_tf_inputs.

if isinstance(value, dict):
    pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
elif isinstance(value, (list, tuple)):
    pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)

@@ -486,135 +488,31 @@ def check_hidden_states_output(config, inputs_dict, model_class):
config.output_hidden_states = True
check_hidden_states_output(config, inputs_dict, model_class)

def test_pt_tf_model_equivalence(self):
Collaborator Author

In the new version, we only need to overwrite prepare_pt_inputs_from_tf_inputs, because that is the only place that actually differs from the common version.
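A hedged sketch of what such an override can look like inside the model's TF test class; the dict/tuple branches mirror the snippet shown above, while the fallback to super() for flat tensors is my illustration rather than necessarily how the LXMERT test is written.

```python
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
    pt_inputs_dict = {}
    for key, value in tf_inputs_dict.items():
        if isinstance(value, dict):
            # Nested input dicts: convert them recursively.
            pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
        elif isinstance(value, (list, tuple)):
            # Sequences of nested inputs: convert each element.
            pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
        else:
            # Flat tensors: defer to the common, dtype-based conversion.
            pt_inputs_dict.update(super().prepare_pt_inputs_from_tf_inputs({key: value}))
    return pt_inputs_dict
```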


check_pt_tf_models(tf_model, pt_model)
super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
Collaborator Author

I prefer to call super() here, because the difference is only about adding a noise argument in the block above.

Collaborator

SGTM!

@@ -363,140 +363,20 @@ def check_hidden_states_output(inputs_dict, config, model_class):

# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
Collaborator Author

We just need to overwrite check_pt_tf_models.
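As a rough illustration only (the exact way the fixed noise is wired into the models may differ), an override inside the ViTMAE TF test class could look roughly like the sketch below; passing the noise as a `noise` input and the `model_tester.batch_size` attribute are assumptions following the usual transformers test setup, not verified details of this PR.

```python
import numpy as np
import tensorflow as tf


def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
    # TFViTMAEForPretraining applies random masking, so fix the noise that drives the
    # masking and feed the same values to both frameworks before running the common check.
    num_patches = int((pt_model.config.image_size // pt_model.config.patch_size) ** 2)
    noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
    # Assumption for illustration: the models accept the noise as a `noise` input, which the
    # common helper converts to a PT tensor when it builds the PT inputs.
    tf_inputs_dict["noise"] = tf.constant(noise, dtype=tf.float32)
    super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
```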

@ydshieh ydshieh changed the title [WIP] Improve pt tf equiv test Improve PT/TF equivalence test Apr 7, 2022
@ydshieh ydshieh marked this pull request as ready for review April 7, 2022 13:27
being a named field in the output.
"""

self.assertEqual(type(name), str)
Collaborator Author

Not sure if we should test this argument. I think it is not worth it.

Collaborator

Not sure why it was added, but it doesn't look useful, I agree.

Collaborator Author

It was added by me during the process: sometimes I passed the wrong arguments and got errors.

However, those arguments are unlikely to be used by anyone else (unless someone wants to change check_pt_tf_outputs).

tf_outputs[pt_nans] = 0

max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
Collaborator Author

Make the failure message more informative by adding the corresponding tensor name, like

output.hidden_states

Collaborator

@sgugger sgugger left a comment

Thanks for cleaning those. It's great we can remove some model-specific code to rely on the generic common tests!

Member

@gante gante left a comment

This is great, it makes writing tests for edge cases much easier 🚀

@ydshieh ydshieh force-pushed the improve_pt_tf_equiv_test branch from cdae60f to b703e6c on April 11, 2022 19:41
Collaborator Author

ydshieh commented Apr 11, 2022

(just rebased on main - no real change since your last review)

Collaborator Author

ydshieh commented Apr 11, 2022

Merging now. Don't hesitate to leave comments in any case :-)

@ydshieh ydshieh merged commit dce33f2 into huggingface:main Apr 11, 2022
@ydshieh ydshieh deleted the improve_pt_tf_equiv_test branch April 11, 2022 20:19
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* add error message

* Use names in the error message

* allow ModelOutput

* rename to check_pt_tf_outputs and move outside

* fix style

* skip past_key_values in a better way

* Add comments

* improve code for label/loss

* make the logic clear by moving the ignore keys out

* fix _postprocessing_to_ignore

* fix _postprocessing_to_ignore: create new outputs from the remaining fields

* ignore past_key_values in TFGPT2 models for now

* make check_pt_tf_outputs better regarding names

* move check_pt_tf_models outside

* rename methods

* remove test_pt_tf_model_equivalence in TFCLIPModelTest

* Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence

* move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models

* Fix quality

* Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence

* Fix quality

* fix

* fix style

* Clean-up TFLEDModelTest.test_pt_tf_model_equivalence

* Fix quality

* add docstring

* improve comment

Co-authored-by: ydshieh <[email protected]>