
Conversation

laserkelvin

This PR adds checks and informative error messages centered around target normalization.

  1. A check in `BaseTaskModule._compute_losses` that appends keys to a `used_norm` list as normalization is performed. Once all losses have been computed, if `used_norm` is empty despite normalization having been requested, we raise a `RuntimeError` to stop training and inform the user that no normalization was performed.
  2. A subsequent check that the number of keys actually used matches the number of keys the user specified.
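The two checks above could be sketched roughly as follows. This is a hedged illustration only: the names `used_norm`, the normalizer interface, and the placeholder squared-error loss are assumptions for the sketch, not the actual `BaseTaskModule._compute_losses` implementation.

```python
class ToyNormalizer:
    """Stand-in for the real normalizer object (assumed interface)."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean, self.std = mean, std

    def normalize(self, value: float) -> float:
        return (value - self.mean) / self.std


def compute_losses_with_norm_checks(targets, predictions, normalizers):
    """Compute per-task losses, tracking which normalization keys were used."""
    used_norm = []
    losses = {}
    for key, target in targets.items():
        if key in normalizers:
            target = normalizers[key].normalize(target)
            used_norm.append(key)
        losses[key] = (predictions[key] - target) ** 2  # placeholder loss
    # Check 1: normalization was requested, but never applied to any target
    if normalizers and not used_norm:
        raise RuntimeError(
            "Normalization was requested, but none of the provided keys "
            f"matched any targets: {list(normalizers)}"
        )
    # Check 2: every user-specified key must actually have been used
    if normalizers and len(used_norm) != len(normalizers):
        unused = set(normalizers) - set(used_norm)
        raise RuntimeError(f"Normalization keys were never used: {unused}")
    return losses
```

The point of raising rather than warning is that the exception surfaces at the first training step, before any compute is wasted on a silently un-normalized run.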

@melo-gonzo I've made it trigger an error rather than a warning, since warnings are probably too easy to ignore, and training would otherwise run to completion on silently un-normalized targets.

Closes #75

Lee, Kin Long Kelvin added 2 commits December 8, 2023 08:38
This change raises an error message for when the user has specified
normalization arguments, but none of them were used.

Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Also check that the number of used keys matches the number of keys
passed by the user.

Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
@laserkelvin laserkelvin added ux User experience, quality of life changes training Issues related to model training labels Dec 8, 2023
@melo-gonzo
Collaborator

I like the idea, but I'm not sure this actually does anything at the moment. The `_make_normalizers` function takes in whatever `normalize_kwargs` are passed, but if they do not contain any of the task keys, it simply writes default values for the expected task keys' normalizers and silently drops whatever was passed in.
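The fallback behavior described here might look something like the following. This is a hypothetical sketch, not the real `_make_normalizers`: the key naming scheme (`<key>_mean` / `<key>_std`) and the default values of 0 and 1 are assumptions made to illustrate why the downstream check never fires.

```python
def make_normalizers(task_keys, normalize_kwargs=None):
    """Hypothetical sketch of the described fallback: kwargs that do not
    match a task key are silently dropped, and missing keys get default
    (mean=0, std=1) values instead of raising."""
    normalize_kwargs = normalize_kwargs or {}
    normalizers = {}
    for key in task_keys:
        # A typo or mismatched key in normalize_kwargs is silently ignored
        # here, so by the time _compute_losses runs, every task key already
        # has a (default) normalizer and the used_norm check always passes.
        mean = normalize_kwargs.get(f"{key}_mean", 0.0)
        std = normalize_kwargs.get(f"{key}_std", 1.0)
        normalizers[key] = (mean, std)
    return normalizers
```

Under this behavior, a mismatch check would need to live here, at normalizer construction time, rather than inside the loss computation.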

@laserkelvin
Author

Okay, that's good to know; I had forgotten that behavior. Let me rewrite this, then!

@laserkelvin laserkelvin marked this pull request as draft December 8, 2023 18:06
Successfully merging this pull request may close these issues:

[Bug]: Normalization keys mismatch fails silently