Improve PT/TF equivalence test #16557
@@ -497,130 +496,6 @@ def test_keras_save_load(self):
        after_outputs = model(inputs_dict)
        self.assert_outputs_same(after_outputs, outputs)

    # overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
This is no longer needed - the test in TF common can handle nested outputs, including instances of ModelOutput.
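To illustrate how a generic test can walk nested outputs, here is a rough, hypothetical sketch (not the actual implementation; a plain dict stands in for ModelOutput) of flattening nested outputs into named leaf tensors, which also shows where names like output.text_model_output.attentions_1 come from:

```python
def flatten_outputs(name, output):
    # Recursively yield (name, tensor) pairs from nested dicts/tuples,
    # appending keys/indices so a failure message can name the exact tensor.
    if isinstance(output, dict):  # stand-in for ModelOutput
        for key, value in output.items():
            yield from flatten_outputs(f"{name}.{key}", value)
    elif isinstance(output, (list, tuple)):
        for i, value in enumerate(output):
            yield from flatten_outputs(f"{name}_{i}", value)
    else:
        yield name, output
```

With this shape, the PT and TF flattened sequences can be zipped and compared leaf by leaf.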
    # TODO: Remove this once a more thorough pt/tf equivalence could be implemented in `test_modeling_tf_common.py`.
    # (Currently, such a test will fail some other model tests: it requires some time to fix them.)
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence_extra(self):
This was added earlier to give TF-LED a strong test while the common version was still a loose test. Now that the common test is (very) strong, we no longer need this test in the TF-LED test file.
    if not is_torch_available():
        return

    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        import torch
Can I add import torch here without is_torch_available or require_torch? This method will be called only inside test_pt_tf_model_equivalence, which is already decorated with is_pt_tf_cross_test.
That's just a marker that reads an env variable, so I think it should have the require_torch just in case, but I'm not sure if we are very consistent with that. @LysandreJik might know better.
I don't think it really matters, as it is indeed already decorated with is_pt_tf_cross_test. We don't have a convention set, so feel free to choose the simplest approach.
        if isinstance(value, dict):
            pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
        elif isinstance(value, (list, tuple)):
            pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
This is the part specific to the LXMERT test.
(It would be possible to move this part to the common PT/TF test method, but I think it's fine/better to overwrite it here.)
        def torch_type(key):
            if key in ("visual_feats", "visual_pos"):
                return torch.float32
            else:
                return torch.long
Removed. The new version uses
elif tf_inputs_dict[key].dtype.is_floating:
which I find cleaner and more general.
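The dtype-based rule can be sketched as follows. This is a minimal stand-alone sketch using NumPy arrays in place of TF/PT tensors (np.issubdtype plays the role of the TF tensor's dtype.is_floating, and int64 mirrors torch.long); the names are illustrative, not the test's actual helpers:

```python
import numpy as np

def convert_by_dtype(value):
    # Floating inputs (e.g. visual_feats, visual_pos) keep a float dtype;
    # everything else (token ids, masks) is cast to int64, like torch.long.
    arr = np.asarray(value)
    if np.issubdtype(arr.dtype, np.floating):
        return arr.astype(np.float32)
    return arr.astype(np.int64)
```

The advantage over the key-name version is that no per-model list of float keys needs to be maintained.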
        def recursive_numpy_convert(iterable):
            return_dict = {}
            for key, value in iterable.items():
                if isinstance(value, dict):
                    return_dict[key] = recursive_numpy_convert(value)
                else:
                    if isinstance(value, (list, tuple)):
                        return_dict[key] = (
                            torch.from_numpy(iter_value.numpy()).to(torch_type(key)) for iter_value in value
                        )
                    else:
                        return_dict[key] = torch.from_numpy(value.numpy()).to(torch_type(key))
            return return_dict
In the new version, this is handled in prepare_pt_inputs_from_tf_inputs.
        if isinstance(value, dict):
            pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
        elif isinstance(value, (list, tuple)):
            pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
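The recursive walk over nested inputs can be sketched stand-alone like this (again a hedged sketch with NumPy arrays standing in for TF/PT tensors; the function name mirrors the test's but this is not its exact code):

```python
import numpy as np

def prepare_inputs_recursively(tf_inputs):
    # Dicts and lists/tuples are traversed; leaves are converted by dtype.
    if isinstance(tf_inputs, dict):
        return {k: prepare_inputs_recursively(v) for k, v in tf_inputs.items()}
    if isinstance(tf_inputs, (list, tuple)):
        return tuple(prepare_inputs_recursively(v) for v in tf_inputs)
    arr = np.asarray(tf_inputs)
    dtype = np.float32 if np.issubdtype(arr.dtype, np.floating) else np.int64
    return arr.astype(dtype)
```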
@@ -486,135 +488,31 @@ def check_hidden_states_output(config, inputs_dict, model_class):
            config.output_hidden_states = True
            check_hidden_states_output(config, inputs_dict, model_class)

    def test_pt_tf_model_equivalence(self):
In the new version, we only need to overwrite prepare_pt_inputs_from_tf_inputs, because that is the place with actual differences from the common version.
        check_pt_tf_models(tf_model, pt_model)
        super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
I prefer to call super() here, because the difference is only about adding a noise argument in the block above.
SGTM!
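To illustrate the noise argument discussed above: a hypothetical sketch of fixing the random masking noise so both frameworks mask identical patches (the shapes, the seed, and the noise keyword in the commented usage are assumptions based on this discussion, not the exact test code):

```python
import numpy as np

def fixed_noise(batch_size, num_patches, seed=42):
    # Deterministic noise shared by the PT and TF ViTMAE models so that
    # random patch masking selects the same patches in both frameworks.
    rng = np.random.default_rng(seed)
    return rng.uniform(size=(batch_size, num_patches)).astype(np.float32)

noise = fixed_noise(2, 196)
# Hypothetical usage (identifiers assumed):
# pt_outputs = pt_model(**pt_inputs, noise=torch.from_numpy(noise))
# tf_outputs = tf_model(**tf_inputs, noise=tf.constant(noise))
```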
@@ -363,140 +363,20 @@ def check_hidden_states_output(inputs_dict, config, model_class):

    # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
    # to generate masks during test
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
We just need to overwrite check_pt_tf_models.
            being a named field in the output.
        """

        self.assertEqual(type(name), str)
Not sure if we should test this argument. I think it is not worth it.
Not sure why it was added, but it doesn't look useful, I agree.
It was added by me during the process: sometimes I passed the wrong arguments and got errors. However, those arguments are unlikely to be used by anyone else (unless someone wants to change check_pt_tf_outputs).
        tf_outputs[pt_nans] = 0

        max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
        self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
Make the failure message more informative by adding the corresponding tensor name, like output.hidden_states.
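A minimal sketch of that tolerance check (a stand-alone simplification with NumPy arrays; the real test also recurses into nested outputs, which is omitted here):

```python
import numpy as np

def check_pt_tf_close(name, tf_out, pt_out, tol=1e-5):
    # Zero out positions that are NaN in either framework's output
    # before comparing, then assert on the max absolute difference.
    nans = np.isnan(tf_out) | np.isnan(pt_out)
    tf_out = np.where(nans, 0.0, tf_out)
    pt_out = np.where(nans, 0.0, pt_out)
    max_diff = np.amax(np.abs(tf_out - pt_out))
    assert max_diff <= tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol})."
```

Including `name` in the message points directly at the offending tensor when a model fails the equivalence test.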
Thanks for cleaning those. It's great we can remove some model-specific code to rely on the generic common tests!
This is great, it makes writing tests for edge cases much easier 🚀
(just rebase on main - no real change since your last review)
Merge now. Don't hesitate to leave comments in any :-)
* add error message
* Use names in the error message
* allow ModelOutput
* rename to check_pt_tf_outputs and move outside
* fix style
* skip past_key_values in a better way
* Add comments
* improve code for label/loss
* make the logic clear by moving the ignore keys out
* fix _postprocessing_to_ignore
* fix _postprocessing_to_ignore: create new outputs from the remaining fields
* ignore past_key_values in TFGPT2 models for now
* make check_pt_tf_outputs better regarding names
* move check_pt_tf_models outside
* rename methods
* remove test_pt_tf_model_equivalence in TFCLIPModelTest
* Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence
* move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models
* Fix quality
* Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence
* Fix quality
* fix
* fix style
* Clean-up TFLEDModelTest.test_pt_tf_model_equivalence
* Fix quality
* add docstring
* improve comment

Co-authored-by: ydshieh <[email protected]>
What does this PR do?

Improve PT/TF equivalence test.

To make the review a bit easier for you, I made some comments. And here is a summary of the changes:

- test_pt_tf_model_equivalence in TensorFlow LED and CLIP are removed: the common one can handle it.
- test_pt_tf_model_equivalence in TensorFlow LXMERT and ViTMAE are removed: we only need to overwrite
  - prepare_pt_inputs_from_tf_inputs for LXMERT
  - check_pt_tf_models for ViTMAE
- TFModelTesterMixin.test_pt_tf_model_equivalence:
  - _make_attention_mask_non_null
  - _postprocessing_to_ignore_test_cases
- check_pt_tf_outputs:
  - ModelOutput (for CLIP model)
  - output.hidden_states or output.text_model_output.attentions_1

Once this PR is approved/merged:
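For illustration of one helper named in the summary, _make_attention_mask_non_null could look roughly like this hypothetical sketch (NumPy stand-in; the exact tensors and keys in the real test may differ). The idea is that a fully-masked sequence row can make PT and TF attention outputs diverge, so the first position of every row is forced to be attended:

```python
import numpy as np

def make_attention_mask_non_null(attention_mask):
    # Ensure no row of the attention mask is all zeros: all-zero rows
    # can produce divergent softmax outputs between PT and TF.
    mask = np.asarray(attention_mask).copy()
    mask[:, 0] = 1  # force the first position of each sequence to be attended
    return mask
```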