# Add Chunked ORPO Loss #362
## Conversation
```python
@staticmethod
def forward(ctx, _input, weight, target, bias=None, ignore_index=-100, beta=0.1, compiled=True):
    """
    Fused linear forward function with ORPO (Odds-Ratio Preference Optimization).
```
Can you please add a reference link to the paper or a blog explaining the function?
```python
        ignore_index (int): Index to ignore for loss computation.
        compiled (bool): Whether to use compiled mode for chunk accumulation.
    """
```
```python
    CHUNK_SIZE = 1
```
Should we add a TODO to tune this param? We might be under-utilizing GPU memory if the number of chunks is always set to batch_size / 2.
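For context, with `CHUNK_SIZE = 1` and a batch laid out as concatenated chosen/rejected halves, the chunk count works out exactly as described. A toy illustration (shapes are made up):

```python
import torch

CHUNK_SIZE = 1  # TODO(tune): larger chunks may improve GPU utilization
batch = torch.randn(8, 16)  # toy batch: 4 chosen rows followed by 4 rejected rows
half = batch.shape[0] // 2

# Each chunk pairs one chosen row with its rejected counterpart, so with
# CHUNK_SIZE == 1 there are always batch_size / 2 chunks.
chunks = [
    (batch[i : i + CHUNK_SIZE], batch[half + i : half + i + CHUNK_SIZE])
    for i in range(0, half, CHUNK_SIZE)
]
print(len(chunks))  # 4 == batch_size / 2
```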
```python
    Args:
        _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
        bias (torch.Tensor, optional): Bias tensor. Shape: (hidden_size,).
```
Shouldn't the shape of the bias term be `vocab_size`?
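Indeed, for a linear projection from `hidden_size` to `vocab_size`, PyTorch stores the bias with shape `(vocab_size,)`:

```python
import torch

# hidden_size=64, vocab_size=1000 are toy values for illustration.
lin = torch.nn.Linear(64, 1000)
print(lin.weight.shape)  # torch.Size([1000, 64]) == (vocab_size, hidden_size)
print(lin.bias.shape)    # torch.Size([1000])     == (vocab_size,)
```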
We're also missing a docstring for `target`.
```python
    reduction="sum",
    ignore_index=ignore_index,
)
chosen_nll_loss = chosen_nll_loss / (target[:target.shape[0] // 2] != ignore_index).sum()
```
If my understanding is correct, this is computing the mean NLL loss for the chosen/accepted response. Why can't we just use `mean` as the reduction method above, rather than computing the sum first and then dividing by the number of non-pad tokens?
The `target` variable here is the whole target, while `target_chunk` is the target for a given chunk. Normalizing directly via `nll_loss` with `reduction="mean"` would normalize by the `target_chunk` length, which is mathematically wrong. So this is kind of a hacky way to do the normalization.
Ahh gotcha, that makes sense, thanks for the explanation.
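To make the normalization point above concrete, here is a small illustrative snippet (toy shapes, not the PR's actual code) contrasting the global mean with per-chunk `reduction="mean"`:

```python
import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(6, 10)  # 6 target tokens, vocab of 10 (toy sizes)
target = torch.tensor([1, 2, ignore_index, 3, 4, ignore_index])
log_probs = F.log_softmax(logits, dim=-1)

# Global mean: sum over all chunks, then divide once by the global number
# of non-ignored tokens (what the PR's "hacky" normalization does).
global_mean = F.nll_loss(
    log_probs, target, reduction="sum", ignore_index=ignore_index
) / (target != ignore_index).sum()

# Per-chunk reduction="mean" divides by each chunk's own non-ignored count,
# so combining chunk means does not recover the global mean unless every
# chunk happens to contain the same number of valid tokens.
chunk_means = [
    F.nll_loss(log_probs[i : i + 2], target[i : i + 2],
               reduction="mean", ignore_index=ignore_index)
    for i in range(0, 6, 2)
]
print(global_mean.item(), (sum(chunk_means) / len(chunk_means)).item())  # generally differ
```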
```python
def accumulate_chunk(input_chunk, target_chunk):
    if bias is not None:
        (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(
```
TIL about `torch.func.grad_and_value`! Cool stuff!
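For readers unfamiliar with it, a minimal standalone example of `torch.func.grad_and_value` (toy function, not the kernel's actual loss):

```python
import torch

def loss_fn(w, x):
    # Toy scalar loss: ((w · x))^2
    return (w * x).sum() ** 2

w = torch.tensor([1.0, 2.0])
x = torch.tensor([3.0, 4.0])

# grad_and_value wraps a function and returns both the gradient w.r.t. the
# chosen argument and the function's value in a single call, which is what
# lets the kernel accumulate per-chunk grads without calling .backward().
grad_w, loss = torch.func.grad_and_value(loss_fn, argnums=0)(w, x)
print(grad_w, loss)  # tensor([66., 88.]) tensor(121.)
```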
```python
else:
    return (per_token_logps * loss_mask).sum(-1)
```

```python
def odds_ratio_loss(
```
We should probably add a link to the original source of the code: https://github.com/huggingface/trl/blob/v0.11.3/trl/trainer/orpo_trainer.py#L75
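For reference, the odds ratio term in that TRL code looks roughly like the following sketch (simplified; defer to the linked source for the authoritative version):

```python
import torch
import torch.nn.functional as F

def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
    # Inputs are average token log-probs, i.e. negative values, so
    # log1p(-exp(logp)) = log(1 - p) is well defined.
    # log-odds of p is logp - log(1 - p); take the difference between the
    # chosen and rejected responses and reward a large gap via log-sigmoid.
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps))
        - torch.log1p(-torch.exp(rejected_logps))
    )
    return beta * F.logsigmoid(log_odds)
```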
`test/chunked_loss/test_orpo_loss.py` (outdated)
```python
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
    # Define input tensors
    _tensor = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
```
What's the purpose of `scalar`?
I believe the intention is to support testing inputs with variances other than one. However, the current tests do not yet include cases with non-unit variances.
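If that's the intent, a hypothetical parametrization like the following (values are illustrative) would actually exercise non-unit variance:

```python
import pytest
import torch

# Illustrative only: scaling a standard-normal tensor by `scalar` yields
# inputs with standard deviation roughly `scalar` instead of 1.
@pytest.mark.parametrize("scalar", [0.1, 1.0, 10.0])
def test_input_scale(scalar):
    x = torch.randn(1024) * scalar
    assert abs(x.std().item() - scalar) < scalar  # loose sanity bound
```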
`test/chunked_loss/test_orpo_loss.py` (outdated)
```python
loss1.backward()
loss2.backward()
# Compare the gradients
assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
```
Please also add assertions for the grads on the weight and bias tensors!
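Something along these lines, shown here as a self-contained toy (two identical `nn.Linear` modules stand in for the two implementations; names are hypothetical, not the test's actual ones):

```python
import torch

lin1, lin2 = torch.nn.Linear(4, 3), torch.nn.Linear(4, 3)
lin2.load_state_dict(lin1.state_dict())  # make the two copies identical

x = torch.randn(2, 4)
lin1(x).sum().backward()
lin2(x).sum().backward()

# The requested checks: grads must match on weight and bias, not just input.
assert torch.allclose(lin1.weight.grad, lin2.weight.grad)
assert torch.allclose(lin1.bias.grad, lin2.bias.grad)
```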
Great stuff! Most of my comments are around documentation, wondering if we should also include a benchmark script for this one!
@shivam15s great work! @pramodith thanks for the detailed review.
```python
chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]

or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
```
Could we rename this to odds ratio loss? The value returned is the ORPO loss = SFT loss + odds ratio loss, not this one.
To clarify, `or_loss` here is the odds ratio loss, while the loss returned from the fn is the sum of the odds ratio loss (`or_loss`) and the SFT loss (`chosen_nll_loss`). Do you want to rename `or_loss` to something more verbose?
Thanks for the clarification. Could we rename `or_loss` to `odds_ratio_loss`, since `or_loss` might be misunderstood as `orpo_loss`?
```python
or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
or_loss = or_loss / (target.shape[0] // 2)

loss = chosen_nll_loss - or_loss
```
Could this class take a flag like `enable_sft_loss`? Users might have their own customized SFT loss logic and need the odds ratio loss only. When the flag is `True`, this class computes `chosen_nll_loss` and returns `chosen_nll_loss - or_loss`; otherwise, it computes and returns the odds ratio loss only.
+1, or we can add an input parameter `chosen_nll_loss_weight` and `loss = chosen_nll_loss_weight * chosen_nll_loss - or_loss`.
There is already a `beta` to balance the weights between the SFT loss and the odds ratio loss: `loss = sft_loss + beta * odds_ratio_loss`.
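A sketch of how such a switch could look (hypothetical function and parameter names; the later diff shows the PR ultimately added a `compute_nll_loss` parameter):

```python
def combine_losses(chosen_nll_loss, or_loss, compute_nll_loss=True):
    # When the caller brings their own SFT loss, skip the fused NLL term
    # and return only the (negated) odds ratio term as the loss.
    if not compute_nll_loss:
        return -or_loss
    return chosen_nll_loss - or_loss
```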
```python
    ignore_index=-100,
    beta=0.1,
    compute_nll_loss=True,
    compiled=True,
```
Is `compiled` used?
```python
class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    @staticmethod
    def forward(
```
General suggestions:

- Can we make `loss_fn` a member method on the base class that raises a not-implemented exception, with derived classes required to implement it, instead of it being a separate Python function? (See the sketch after this list.)
- Can we use `super()` instead of `LigerFusedLinearPreferenceBase.forward` (backward)? Or maybe we can't do that for a static method.
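A sketch of the first suggestion (hypothetical class and method names; the actual refactor landed in the follow-up PR summarized below):

```python
import torch
import torch.nn.functional as F

class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    """Base class: the preference-specific loss lives on the class
    and must be overridden by each derived class."""

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        raise NotImplementedError("Derived classes must implement the preference loss.")

class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # ORPO's odds ratio term, mirroring the sketch earlier in the thread.
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps))
            - torch.log1p(-torch.exp(rejected_logps))
        )
        return beta * F.logsigmoid(log_odds)
```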
I will merge this PR first and we can address later! Great contribution!!
## Summary

This PR refactors the `LigerFusedLinearPreferenceBase` class to contain an abstract method corresponding to the calculation of the loss, which needs to be implemented by all sub-classes. It also adds a new function to the class called `_compute_loss`, which is mostly the same as the `_compute_orpo_loss` function introduced in #362 but makes it generic: it calculates the NLL/Cross Entropy Loss and accepts a custom loss function that implements a new alignment loss. Most RLHF/RLAIF/Alignment algorithms state their final loss as `NLL + Beta * (Alignment_Loss)`, so adding the NLL logic inside the base class reduces repeated code.

## Testing Done

On A100-80G-SXM

- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Co-authored-by: pramodith <[email protected]>
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss kernel, as requested in #371. This implementation is largely based on the excellent work done on ORPO (#362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free):

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x))$$

$$-\log(\sigma((\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)) - \log(\pi_{\theta_{\text{ref}}}(y_c|x)) + \log(\pi_{\theta_{\text{ref}}}(y_r|x)))/\beta))$$

Corresponds to:

```python
# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss
logits_diff = (chosen_advantages - rejected_advantages) / beta
losses = -F.logsigmoid(logits_diff)
```

In this PR:

1. The above mathematical equation shows that to maximize the reward difference, we get the formula: $$r_\theta(x_c) - r_\theta(x_r)$$
2. This can be further optimized using just: $$-\log(\sigma((\pi_\theta(x_c) - \pi_\theta(x_r))/\beta))$$
3. So, the code implements:
```python
logits_diff = (chosen_logps - rejected_logps) / beta  # (π_θ(x_c) - π_θ(x_r))/β
losses = -F.logsigmoid(logits_diff)  # -log(σ(logits_diff))
```
4. Sum up DPO and NLL: $$L_{DPO+NLL} = L_{DPO} + \alpha L_{NLL}$$

## Testing Done

*(benchmark plots attached in the original PR)*

- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
## Summary

Adds chunked ORPO loss kernel.

## Testing Done

Benchmarks:

*(benchmark plots attached in the original PR)*

References:

- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence