
Conversation

@shivam15s (Collaborator) commented Nov 8, 2024

## Summary

Adds a chunked ORPO loss kernel.

## Testing Done

Benchmarks: Speed ORPO and Mem ORPO (benchmark plots attached in the PR).

References:

  1. Torch compiled FLCE is 2x faster than the current FLCE #227
  2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
  • Hardware Type:
  • run `make test` to ensure correctness
  • run `make checkstyle` to ensure code style
  • run `make test-convergence` to ensure convergence

@shivam15s marked this pull request as draft November 8, 2024 01:25
```python
@staticmethod
def forward(ctx, _input, weight, target, bias=None, ignore_index=-100, beta=0.1, compiled=True):
    """
    Fused linear forward function with ORPO (Odds-Ratio Preference Optimization).
```
Collaborator: Can you please add a reference link to the paper or a blog explaining the function?

```python
    ignore_index (int): Index to ignore for loss computation.
    compiled (bool): Whether to use compiled mode for chunk accumulation.
    """
```
```python
CHUNK_SIZE = 1
```
Collaborator: Should we add a TODO to tune this param? We might be under-utilizing GPU memory if the number of chunks is always set to batch_size/2.

```python
Args:
    _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
    weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
    bias (torch.Tensor, optional): Bias tensor. Shape: (hidden_size,).
```
Collaborator: Shouldn't the shape of the bias term be vocab_size?

Collaborator: We're also missing a docstring for target.

reduction="sum",
ignore_index=ignore_index,
)
chosen_nll_loss = chosen_nll_loss / (target[:target.shape[0]//2] != ignore_index).sum()
@pramodith (Collaborator) commented Nov 8, 2024: If my understanding is correct, this is computing the mean NLL loss for the chosen/accepted response; why can't we just use mean as the reduction method above, rather than computing the sum first and then dividing by the number of non-pad tokens?

@shivam15s (Author) replied Nov 8, 2024: The target variable here is the whole target, while target_chunk is the target for a given chunk. Normalizing directly with nll_loss reduction="mean" would normalize by the target_chunk length, which is mathematically wrong. So this is a somewhat hacky way to get the normalization right (see the sketch after this exchange).

Collaborator: Ahh gotcha, that makes sense, thanks for the explanation.
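
For illustration, here is a minimal, self-contained sketch of why a per-chunk reduction="mean" cannot simply be accumulated (toy values; variable names are hypothetical and not from the PR):

```python
import torch

ignore_index = -100
target = torch.tensor([1, 2, -100, 3, -100, -100])          # toy labels
token_loss = torch.tensor([1.0, 3.0, 0.0, 5.0, 0.0, 0.0])   # toy per-token losses

# Correct: sum per-chunk losses, then divide once by the global valid-token count.
valid = target != ignore_index
true_mean = token_loss[valid].sum() / valid.sum()            # (1 + 3 + 5) / 3 = 3.0

# Per-chunk mean instead: each chunk is normalized by its own valid count.
acc = 0.0
for t, l in zip(target.split(3), token_loss.split(3)):
    mask = t != ignore_index
    acc += l[mask].sum() / mask.sum()                        # 4/2 = 2.0, then 5/1 = 5.0
# acc == 7.0; no fixed post-hoc division recovers 3.0 when chunks hold
# unequal numbers of valid tokens, hence the sum-then-global-divide above.
```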


```python
def accumulate_chunk(input_chunk, target_chunk):
    if bias is not None:
        (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(
```
Collaborator: TIL about torch.func.grad_and_value! Cool stuff!
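
For readers unfamiliar with it, a minimal standalone sketch of torch.func.grad_and_value (toy function; names are illustrative, not the kernel code):

```python
import torch

def loss_fn(weight, x):
    return ((x @ weight) ** 2).sum()

weight = torch.randn(4, 4)
x = torch.randn(2, 4)

# grad_and_value returns the gradient w.r.t. argnums and the function value
# in one call, which is how accumulate_chunk gets per-chunk grads and loss together.
grad_weight, loss = torch.func.grad_and_value(loss_fn, argnums=0)(weight, x)
```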

```python
else:
    return (per_token_logps * loss_mask).sum(-1)

def odds_ratio_loss(
```
Collaborator: We should probably add a link to the original source of the code: https://github.com/huggingface/trl/blob/v0.11.3/trl/trainer/orpo_trainer.py#L75
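
For context, a paraphrased sketch of the odds-ratio term along the lines of that TRL implementation, written under the sign convention used later in this PR (the result is subtracted from the NLL loss); this is illustrative, not the exact kernel code:

```python
import torch
import torch.nn.functional as F

def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
    # log( odds(y_chosen) / odds(y_rejected) ), where odds(y) = p(y) / (1 - p(y))
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps))
        - torch.log1p(-torch.exp(rejected_logps))
    )
    # Beta-scaled log-sigmoid of the log-odds ratio; subtracting this from the
    # NLL loss pushes up the odds of the chosen response over the rejected one.
    return beta * F.logsigmoid(log_odds).sum()
```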

```python
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
    # Define input tensors
    _tensor = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
```
Collaborator: What's the purpose of scalar?

Author (Collaborator): I believe the intention is to support testing inputs with variances other than one. However, the current tests do not yet include cases with non-unit variances.
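
As a quick illustration of that intent (toy values, not from the test suite):

```python
import torch

# torch.randn samples from N(0, 1); multiplying by `scalar` yields variance scalar**2,
# so parametrizing scalar over values other than 1.0 would exercise non-unit-variance inputs.
scalar = 10.0
x = torch.randn(1024, 1024) * scalar
print(x.std())   # ~10.0
```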

```python
loss1.backward()
loss2.backward()
# Compare the gradients
assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
```
Collaborator: Please also add assertions for the grads on the weight and bias tensors! (A possible shape is sketched below.)
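
Something along these lines would do it (tensor names are hypothetical, mirroring the _input1/_input2 pattern above):

```python
assert torch.allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
if bias:
    assert torch.allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
```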

@pramodith (Collaborator) left a comment: Great stuff! Most of my comments are around documentation; wondering if we should also include a benchmark script for this one!

@lancerts (Collaborator) commented Nov 9, 2024

> Great stuff! Most of my comments are around documentation; wondering if we should also include a benchmark script for this one!

@shivam15s great work! @pramodith thanks for the detailed review.
And yes, we should add the bench script.

```python
chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]

or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
```

Comment: Could we rename this to odds_ratio_loss? The returned loss is orpo_loss = sft_loss + odds_ratio_loss, not this one.

Author (Collaborator): To clarify, or_loss here is the odds ratio loss, while the loss returned from the function is the sum of the odds ratio loss (or_loss) and the SFT loss (chosen_nll_loss). Do you want to rename or_loss to something more verbose?

Comment: Thanks for the clarification. Could we rename or_loss to odds_ratio_loss, since or_loss might be misunderstood as orpo_loss?

```python
or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
or_loss = or_loss / (target.shape[0] // 2)

loss = chosen_nll_loss - or_loss
```

Comment: Could this class take a flag like enable_sft_loss? Users might have their own customized SFT loss logic and need the odds ratio loss only. When the flag is True, this class computes chosen_nll_loss and returns chosen_nll_loss - or_loss; otherwise, it computes and returns the odds ratio loss only. (A sketch follows below.)
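
A hypothetical sketch of that flag (enable_sft_loss is the reviewer's suggestion, not part of this PR; the sign convention follows the surrounding code, where or_loss acts as a reward-like term that is subtracted):

```python
if enable_sft_loss:
    loss = chosen_nll_loss - or_loss   # full ORPO objective: SFT + odds-ratio term
else:
    loss = -or_loss                    # odds-ratio term only; caller adds its own SFT loss
```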

Collaborator: +1, or we can add an input parameter chosen_nll_loss_weight and compute loss = chosen_nll_loss_weight * chosen_nll_loss - or_loss.

Comment: There is already a beta to balance the weights between the SFT loss and the odds-ratio loss: loss = sft_loss + beta * odds_ratio_loss.

@shivam15s marked this pull request as ready for review November 14, 2024 01:34
```python
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
compiled=True,
```
Contributor: Is compiled used?


```python
class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    @staticmethod
    def forward(
```
Contributor: General suggestions:

  1. Can we make loss_fn a member method in the base class that raises NotImplementedError, with derived classes required to implement it, instead of it being a separate Python function? (A sketch follows below.)

  2. Can we use super() instead of LigerFusedLinearPreferenceBase.forward (and backward)? Or is that not possible for a static method?

I will merge this PR first and we can address these later! Great contribution!!
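
A minimal sketch of the first suggestion (method name and signature hypothetical; the follow-up commit below implements this idea as an abstractmethod):

```python
import torch

class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # Each derived class (e.g. an ORPO subclass) overrides this with its
        # own alignment loss; the base class only defines the contract.
        raise NotImplementedError("Derived classes must implement preference_loss_fn.")
```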

@ByronHsu merged commit 6b2fd02 into main Nov 14, 2024
3 checks passed
@ByronHsu deleted the shisahni/orpo branch November 14, 2024 04:12
pramodith added a commit that referenced this pull request Nov 14, 2024
## Summary
This PR refactors the `LigerFusedLinearPreferenceBase` class to contain
an abstractmethod corresponding to the calculation of the loss that
needs to be implemented by all sub-classes.

It also adds a new function to the class called `_compute_loss` which is
mostly the same as the `_compute_orpo_loss` function introduced in #362
but makes it generic to calculate the NLL/Cross Entropy Loss plus
accepts a custom loss function that implements a new alignment loss
function.

Most RLHF/RLAIF/alignment algorithms state their final loss as
`NLL + Beta * (Alignment_Loss)`, so adding the NLL logic inside the base class
reduces repeated code.

The `_compute_loss` function accepts the custom alignment loss function described above as an argument.

## Testing Done

On A100-80G-SXM


- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Co-authored-by: pramodith <[email protected]>
ByronHsu pushed a commit that referenced this pull request Nov 15, 2024
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct
Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss
kernel, as requested in
#371.
This implementation is largely based on the excellent work done on ORPO
(#362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free):

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x))$$

$$-\log(\sigma((\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)) - \log(\pi_{\theta_{\text{ref}}}(y_c|x)) + \log(\pi_{\theta_{\text{ref}}}(y_r|x)))/\beta))$$

Corresponds to:
```python
# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss
logits_diff = (chosen_advantages - rejected_advantages) / beta
losses = -F.logsigmoid(logits_diff)
```

In this PR:

1. The above mathematical equation shows that to maximize the reward
difference, we get the formula:
    $$r_\theta(x_c) - r_\theta(x_r)$$
2. This can be further optimized using just:
    $$-\log(\sigma((\pi_\theta(x_c) - \pi_\theta(x_r))/\beta))$$
3. So, the code implements:
    ```python
    logits_diff = (chosen_logps - rejected_logps) / beta  # (π_θ(x_c) - π_θ(x_r)) / β
    losses = -F.logsigmoid(logits_diff)                   # -log(σ(logits_diff))
    ```
4. Sum up DPO and NLL:
    $$L_{DPO+NLL} = L_{DPO} + \alpha L_{NLL}$$

## Testing Done


![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)

![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)


- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>