1
+ import os
2
+ import sys
3
+
1
4
import torch
2
5
import triton
3
6
from utils import (
11
14
12
15
from liger_kernel .chunked_loss .dpo_loss import LigerFusedLinearDPOFunction
13
16
17
+ sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
18
+
14
19
15
20
class TorchDPOLoss (torch .nn .Module ):
16
21
def __init__ (
@@ -30,12 +35,14 @@ def __init__(
30
35
)
31
36
self .dpo_loss = HF_DPO_Loss (beta = beta , ignore_index = ignore_index )
32
37
33
def forward(self, x, target, ref_chosen_logps, ref_rejected_logps):
    """Compute DPO batch-loss metrics for activations ``x``.

    Delegates to the wrapped ``HF_DPO_Loss`` instance (bound to
    ``self.dpo_loss`` in ``__init__``), forwarding the linear layer's
    weight and optional bias together with the precomputed
    reference-model log-probabilities for the chosen/rejected halves.
    """
    # hasattr guard kept from the original: nn.Linear exposes `bias`
    # (possibly None) either way, so this is effectively always taken.
    lin_bias = self.lin.bias if hasattr(self.lin, "bias") else None
    return self.dpo_loss.get_batch_loss_metrics(
        x,
        self.lin.weight,
        target,
        lin_bias,
        ref_chosen_logps,
        ref_rejected_logps,
    )
40
47
41
48
@@ -51,20 +58,13 @@ def __init__(
51
58
):
52
59
super ().__init__ ()
53
60
self .lin = torch .nn .Linear (
54
- in_features = H , out_features = V , bias = bias , dtype = dtype
61
+ in_features = H , out_features = V , bias = False , dtype = dtype
55
62
)
56
- self .beta = beta
57
- self .ignore_index = ignore_index
63
+ self .dpo_loss = LigerFusedLinearDPOFunction .apply
58
64
59
def forward(self, x, y, ref_chosen_logps, ref_rejected_logps):
    """Run the fused linear-projection + DPO loss on activations ``x``.

    ``self.dpo_loss`` is bound to ``LigerFusedLinearDPOFunction.apply``
    in ``__init__``; passing the layer weight lets the projection be
    fused into the loss computation.
    """
    # NOTE(review): unlike the Torch reference path, no bias is passed
    # here — the layer is constructed with bias=False.
    weight = self.lin.weight
    return self.dpo_loss(x, weight, y, ref_chosen_logps, ref_rejected_logps)
69
69
70
70
@@ -91,6 +91,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
91
91
_input = torch .randn (B , T , H , device = device , dtype = dtype )
92
92
# Target shape: [B, T]
93
93
target = torch .randint (V , (B , T ), dtype = torch .long , device = device )
94
+
95
+ ref_chosen_logps = torch .randn (B // 2 , device = device )
96
+ ref_rejected_logps = torch .randn (B // 2 , device = device )
94
97
95
98
# Add ignore_index tokens to simulate padding
96
99
num_elements_to_assign = torch .randint (1 , B * T // 2 , (1 ,)).item ()
@@ -99,9 +102,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
99
102
100
103
def fwd ():
101
104
if provider == "liger" :
102
- return liger_dpo_loss (_input , target )
105
+ return liger_dpo_loss (_input , target , ref_chosen_logps , ref_rejected_logps )
103
106
elif provider == "huggingface" :
104
- return torch_dpo_loss (_input , target )
107
+ return torch_dpo_loss (_input , target , ref_chosen_logps , ref_rejected_logps )
105
108
106
109
def full ():
107
110
y = fwd ()
@@ -137,20 +140,22 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
137
140
138
141
# Input shape: [B, T, H]
139
142
_input = torch .randn (B , T , H , device = device , dtype = dtype )
140
-
141
143
# Target shape: [B, T]
142
144
target = torch .randint (V , (B , T ), device = device , dtype = torch .long )
143
145
146
+ ref_chosen_logps = torch .randn (B // 2 , device = device )
147
+ ref_rejected_logps = torch .randn (B // 2 , device = device )
148
+
144
149
# Add ignore_index tokens
145
150
num_elements_to_assign = torch .randint (1 , B * T // 2 , (1 ,)).item ()
146
151
indices_to_assign = torch .randperm (B * T )[:num_elements_to_assign ]
147
152
target .view (- 1 )[indices_to_assign ] = ignore_index
148
153
149
154
def fwd ():
150
155
if provider == "liger" :
151
- return liger_dpo_loss (_input , target )
156
+ return liger_dpo_loss (_input , target , ref_chosen_logps , ref_rejected_logps )
152
157
elif provider == "huggingface" :
153
- return torch_dpo_loss (_input , target )
158
+ return torch_dpo_loss (_input , target , ref_chosen_logps , ref_rejected_logps )
154
159
155
160
if mode == "forward" :
156
161
ms_50 , ms_20 , ms_80 = triton .testing .do_bench (
0 commit comments