
Commit 81c6411

Fix DPO benchmark by adding reference logps
Signed-off-by: Austin Liu <[email protected]>
1 parent e4d868e commit 81c6411

1 file changed (+23 −18 lines)
benchmark/scripts/benchmark_dpo_loss.py

Lines changed: 23 additions & 18 deletions
@@ -1,3 +1,6 @@
+import os
+import sys
+
 import torch
 import triton
 from utils import (
@@ -11,6 +14,8 @@

 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+

 class TorchDPOLoss(torch.nn.Module):
     def __init__(
@@ -30,12 +35,14 @@ def __init__(
         )
         self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)

-    def forward(self, x, target):
+    def forward(self, x, target, ref_chosen_logps, ref_rejected_logps):
         return self.dpo_loss.get_batch_loss_metrics(
             x,
             self.lin.weight,
             target,
             self.lin.bias if hasattr(self.lin, "bias") else None,
+            ref_chosen_logps,
+            ref_rejected_logps,
         )

@@ -51,20 +58,13 @@ def __init__(
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=bias, dtype=dtype
+            in_features=H, out_features=V, bias=False, dtype=dtype
         )
-        self.beta = beta
-        self.ignore_index = ignore_index
+        self.dpo_loss = LigerFusedLinearDPOFunction.apply

-    def forward(self, x, target):
-        return LigerFusedLinearDPOFunction.apply(
-            x,
-            self.lin.weight,
-            target,
-            self.lin.bias if hasattr(self.lin, "bias") else None,
-            self.ignore_index,
-            self.beta,
-            True,
+    def forward(self, x, y, ref_chosen_logps, ref_rejected_logps):
+        return self.dpo_loss(
+            x, self.lin.weight, y, ref_chosen_logps, ref_rejected_logps
         )

@@ -91,6 +91,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
     _input = torch.randn(B, T, H, device=device, dtype=dtype)
     # Target shape: [B, T]
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+
+    ref_chosen_logps = torch.randn(B // 2, device=device)
+    ref_rejected_logps = torch.randn(B // 2, device=device)

     # Add ignore_index tokens to simulate padding
     num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
@@ -99,9 +102,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO

     def fwd():
         if provider == "liger":
-            return liger_dpo_loss(_input, target)
+            return liger_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)
         elif provider == "huggingface":
-            return torch_dpo_loss(_input, target)
+            return torch_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)

     def full():
         y = fwd()
@@ -137,20 +140,22 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu

     # Input shape: [B, T, H]
     _input = torch.randn(B, T, H, device=device, dtype=dtype)
-
     # Target shape: [B, T]
     target = torch.randint(V, (B, T), device=device, dtype=torch.long)

+    ref_chosen_logps = torch.randn(B // 2, device=device)
+    ref_rejected_logps = torch.randn(B // 2, device=device)
+
     # Add ignore_index tokens
     num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
     indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
     target.view(-1)[indices_to_assign] = ignore_index

     def fwd():
         if provider == "liger":
-            return liger_dpo_loss(_input, target)
+            return liger_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)
         elif provider == "huggingface":
-            return torch_dpo_loss(_input, target)
+            return torch_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)

     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
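
For context, a minimal sketch of how the updated benchmark wrappers are called after this commit, inferred from the diff alone: the batch is treated as B // 2 chosen/rejected pairs, so the precomputed reference log-probs have shape [B // 2]. The concrete sizes, dtype, and device below are illustrative assumptions, not values taken from the benchmark config.

import torch

# Assumed sizes; B must be even so the batch splits into chosen/rejected halves.
B, T, H, V = 8, 128, 1024, 32000
device, dtype = "cuda", torch.bfloat16

_input = torch.randn(B, T, H, device=device, dtype=dtype)            # [B, T, H]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)   # [B, T]

# One precomputed reference log-prob per preference pair, as added in this commit.
ref_chosen_logps = torch.randn(B // 2, device=device)
ref_rejected_logps = torch.randn(B // 2, device=device)

# Both wrappers now take the reference log-probs explicitly, e.g.:
#   liger_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)
#   torch_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)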
