-
Notifications
You must be signed in to change notification settings - Fork 406
Improve lead time support for diffusion models #980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
/blossom-ci |
/blossom-ci |
/blossom-ci |
/blossom-ci |
/blossom-ci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All changes proposed look good to me. Just a few details would require some improvements:
-
As far as I understand, this PR decouples lead-time embeddings from positional embeddings, in order to allow more flexibility in using them independently from each other. This new functionality does not seem to be used in any of the training recipes/examples. It could be useful to detail in the PR description the broader context (e.g. which applications is it going to be applied to? Will there be a follow-up PR? etc...)
-
The new flexibility to independently use lead-time and positional embeddings should be clearly explained in the docstrings.
-
IMO the current implementation of the lead-time embeddings has too many failure modes to be safely exposed to broader applications. For example, in
positional_embedding_indexing
:
lead_time_label
can be done whileself.lt_embd
is notNone
, which leads to an error- Conversely,
lead_time_label
could be a user-provided tensor, whileself.lt_embd
isNone
, which leads tolead_time_label
being silently ignored.
I strongly support better parameters validation to eliminate this failure modes, either in the forward
method or the __Init__
when possible.
Hi @CharlelieLrt,
I would say lead-time embeddings were already decoupled from positional embeddings before the PR. This PR just includes some fixes to make sure that they can be enabled when positional embeddings are disabled, or vice versa.
As they were already implemented independently, I don't think it's a new flexibility, but I can improve the docstrings in that regard.
I'll add some checks to make sure the inputs conform with the model configuration (but note that as far as I understand, these failure modes already existed before the PR). |
/blossom-ci |
58974a1
to
e2fdd39
Compare
/blossom-ci |
Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
…e code Signed-off-by: Charlelie Laurent <[email protected]>
…bels Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
…ime models Signed-off-by: Charlelie Laurent <[email protected]>
Signed-off-by: Charlelie Laurent <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
/blossom-ci |
…ts accordingly Signed-off-by: Charlelie Laurent <[email protected]>
/blossom-ci |
PhysicsNeMo Pull Request
Description
Adds/fixes support for lead-time labels in various places where it was missing or not working:
SongUNetPosEmbd
now works properly using either, both or neither of positional embedding and lead-time embedding. In the previous version some pieces of code could try to access properties of these even when set toNone
.deterministic_sampler
now accepts lead-time labels and passes them through to the model, if given.EDMLoss
also now supports lead-time labels.Checklist
Dependencies
No new dependencies needed.