@@ -19,8 +19,12 @@ class HFDPOLoss(HFAlignmentLoss):
19
19
Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py
20
20
"""
21
21
22
- def __init__ (self , ignore_index : int = - 100 , beta : float = 0.1 , use_ref_model : bool = True ):
23
- super ().__init__ (beta = beta , ignore_index = ignore_index , use_ref_model = use_ref_model )
22
+ def __init__ (
23
+ self , ignore_index : int = - 100 , beta : float = 0.1 , use_ref_model : bool = True
24
+ ):
25
+ super ().__init__ (
26
+ beta = beta , ignore_index = ignore_index , use_ref_model = use_ref_model
27
+ )
24
28
25
29
def alignment_loss (
26
30
self ,
@@ -69,7 +73,9 @@ def __init__(
69
73
).get_batch_loss_metrics
70
74
71
75
def forward (self , x , y ):
72
- return self .dpo_loss (self .lin .weight , x , y , self .lin .bias , self .ref_lin .weight , self .ref_lin .bias )
76
+ return self .dpo_loss (
77
+ self .lin .weight , x , y , self .lin .bias , self .ref_lin .weight , self .ref_lin .bias
78
+ )
73
79
74
80
75
81
class LigerLMHeadDPO (torch .nn .Module ):
@@ -90,10 +96,14 @@ def __init__(
90
96
self .ref_lin = torch .nn .Linear (
91
97
in_features = H , out_features = V , bias = ref_bias , dtype = dtype
92
98
)
93
- self .dpo_loss = LigerFusedLinearDPOLoss (ignore_index = ignore_index , beta = beta , use_ref_model = True )
99
+ self .dpo_loss = LigerFusedLinearDPOLoss (
100
+ ignore_index = ignore_index , beta = beta , use_ref_model = True
101
+ )
94
102
95
103
def forward (self , x , y ):
96
- return self .dpo_loss (self .lin .weight , x , y , self .lin .bias , self .ref_lin .weight , self .ref_lin .bias )
104
+ return self .dpo_loss (
105
+ self .lin .weight , x , y , self .lin .bias , self .ref_lin .weight , self .ref_lin .bias
106
+ )
97
107
98
108
99
109
@pytest .mark .parametrize (
@@ -113,7 +123,9 @@ def forward(self, x, y):
113
123
@pytest .mark .parametrize ("bias" , [True , False ])
114
124
@pytest .mark .parametrize ("ref_bias" , [True , False ])
115
125
@pytest .mark .parametrize ("ignore_index, beta" , [(- 100 , 0.1 ), (42 , 0.2 )])
116
- def test_correctness (B , T , H , V , scalar , dtype , atol , rtol , bias , ref_bias , ignore_index , beta ):
126
+ def test_correctness (
127
+ B , T , H , V , scalar , dtype , atol , rtol , bias , ref_bias , ignore_index , beta
128
+ ):
117
129
B = 2 * B # dpo loss requires B to be even
118
130
119
131
torch_lm_head_dpo = TorchLMHeadDPO (
@@ -138,17 +150,17 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, igno
138
150
torch_lm_head_dpo .lin .weight .data = liger_lm_head_dpo .lin .weight .data = torch .randn (
139
151
V , H , device = "cuda" , dtype = dtype
140
152
)
141
- torch_lm_head_dpo .ref_lin .weight .data = liger_lm_head_dpo .ref_lin .weight .data = torch . randn (
142
- V , H , device = "cuda" , dtype = dtype
153
+ torch_lm_head_dpo .ref_lin .weight .data = liger_lm_head_dpo .ref_lin .weight .data = (
154
+ torch . randn ( V , H , device = "cuda" , dtype = dtype )
143
155
)
144
156
145
157
if bias :
146
158
torch_lm_head_dpo .lin .bias .data = liger_lm_head_dpo .lin .bias .data = torch .randn (
147
159
V , device = "cuda" , dtype = dtype
148
160
)
149
161
if ref_bias :
150
- torch_lm_head_dpo .ref_lin .bias .data = liger_lm_head_dpo .ref_lin .bias .data = torch . randn (
151
- V , device = "cuda" , dtype = dtype
162
+ torch_lm_head_dpo .ref_lin .bias .data = liger_lm_head_dpo .ref_lin .bias .data = (
163
+ torch . randn ( V , device = "cuda" , dtype = dtype )
152
164
)
153
165
154
166
_input = torch .randn (B , T , H , device = "cuda" , dtype = dtype ) * scalar
@@ -244,8 +256,12 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
244
256
ref_bias1 = _ref_bias .detach ().clone ().requires_grad_ (True ) if ref_bias else None
245
257
ref_bias2 = _ref_bias .detach ().clone ().requires_grad_ (True ) if ref_bias else None
246
258
247
- loss1 = LigerFusedLinearDPOFunction .apply (input1 , weight1 , target , bias1 , ref_weight1 , ref_bias1 )
248
- loss2 = liger_fused_linear_dpo (input2 , weight2 , target , bias2 , ref_weight2 , ref_bias2 )
259
+ loss1 = LigerFusedLinearDPOFunction .apply (
260
+ input1 , weight1 , target , bias1 , ref_weight1 , ref_bias1
261
+ )
262
+ loss2 = liger_fused_linear_dpo (
263
+ input2 , weight2 , target , bias2 , ref_weight2 , ref_bias2
264
+ )
249
265
250
266
assert_verbose_allclose (loss1 , loss2 , atol = atol , rtol = rtol )
251
267
0 commit comments