Skip to content

add reference model logps to chunkedloss interface and fix dpo loss fn #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 22, 2024

Conversation

shivam15s
Copy link
Collaborator

@shivam15s shivam15s commented Nov 21, 2024

accomodate reference model logps in chunked loss interface and make dpo loss use reference model logps in its loss function

Summary

as title

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@shivam15s shivam15s changed the title change interface to accomodate reference model logps [chunked loss] accomodate reference model logps and dpo loss w ref Nov 21, 2024
@shivam15s shivam15s changed the title [chunked loss] accomodate reference model logps and dpo loss w ref [chunkedloss] add reference model logps and add ref logps to dpo Nov 21, 2024
Copy link
Collaborator

@pramodith pramodith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Wrapping DPO up.

input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_index=-100
):
with torch.no_grad():
ref_logits_chunk = input_chunk @ ref_weight.t()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic to get the log_probs is common for both the reference and active policy model, wondering if we can create a util function that can be called by both instead of repeating the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be possible; just need the util fn to compute nll loss for the active policy forward

@shivam15s shivam15s changed the title [chunkedloss] add reference model logps and add ref logps to dpo add reference model logps to chunkedloss interface and fix dpo loss fn Nov 21, 2024
@shivam15s shivam15s requested a review from pramodith November 21, 2024 23:57
Copy link
Collaborator

@pramodith pramodith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the refactor!

@austin362667 austin362667 mentioned this pull request Nov 22, 2024
3 tasks
@shivam15s shivam15s merged commit d907ec0 into main Nov 22, 2024
3 checks passed
@shivam15s shivam15s deleted the shisahni/dpo_ref branch November 22, 2024 05:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants