-
Notifications
You must be signed in to change notification settings - Fork 383
add reference model logps to chunkedloss interface and fix dpo loss fn #405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…s use reference model
116ee1a
to
15cb383
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! Wrapping DPO up.
input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_index=-100 | ||
): | ||
with torch.no_grad(): | ||
ref_logits_chunk = input_chunk @ ref_weight.t() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic to get the log_probs is common for both the reference and active policy model, wondering if we can create a util function that can be called by both instead of repeating the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be possible; just need the util fn to compute nll loss for the active policy forward
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the refactor!
accomodate reference model logps in chunked loss interface and make dpo loss use reference model logps in its loss function
Summary
as title
Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence