
Commit 116ee1a

checkstyle
1 parent 58a428a · commit 116ee1a

4 files changed (+74, -25 lines)

src/liger_kernel/chunked_loss/dpo_loss.py

Lines changed: 10 additions & 2 deletions

@@ -9,7 +9,13 @@
 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):

     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps,
+        rejected_logps,
+        ref_chosen_logps=None,
+        ref_rejected_logps=None,
+        beta=0.1,
+    ):
         """
         Compute DPO loss (Direct Preference Optimization).
         Args:

@@ -102,7 +108,9 @@ def __init__(
         self.compiled = compiled
         self.use_ref_model = use_ref_model

-    def forward(self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None):
+    def forward(
+        self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+    ):
         return LigerFusedLinearDPOFunction.apply(
             _input,
             lin_weight,
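The body of preference_loss_fn is unchanged by this commit, so it does not appear in the hunks above. For context, a minimal standalone sketch of the standard DPO objective over these same arguments (not the fused kernel implementation; the final sum reduction is an assumption):

import torch.nn.functional as F

def dpo_loss_sketch(
    chosen_logps, rejected_logps, ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1
):
    # Log-ratio of the policy between chosen and rejected completions.
    logits = chosen_logps - rejected_logps
    # When a reference model is used, subtract its log-ratio as well.
    if ref_chosen_logps is not None and ref_rejected_logps is not None:
        logits = logits - (ref_chosen_logps - ref_rejected_logps)
    # Standard DPO: minimize -log(sigmoid(beta * margin)); reduction is assumed here.
    return -F.logsigmoid(beta * logits).sum()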

src/liger_kernel/chunked_loss/fused_linear_preference.py

Lines changed: 19 additions & 7 deletions

@@ -19,7 +19,9 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
         raise NotImplementedError("Preference loss function must be implemented.")

     @staticmethod
-    def get_ref_logps(input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_index=-100):
+    def get_ref_logps(
+        input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_index=-100
+    ):
         with torch.no_grad():
             ref_logits_chunk = input_chunk @ ref_weight.t()
             if ref_bias is not None:

@@ -29,11 +31,15 @@ def get_ref_logps(input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_i
             loss_mask = target_chunk != ignore_index
             label_chunk = torch.where(loss_mask, target_chunk, 0)

-            ref_per_token_logps = ref_log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-            ref_average_log_prob = (ref_per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+            ref_per_token_logps = ref_log_probs_chunk.gather(
+                -1, label_chunk.unsqueeze(-1)
+            ).squeeze(-1)
+            ref_average_log_prob = (ref_per_token_logps * loss_mask).sum(
+                -1
+            ) / loss_mask.sum(-1)

-            ref_chosen_logps = ref_average_log_prob[:input_chunk.shape[0] // 2]
-            ref_rejected_logps = ref_average_log_prob[input_chunk.shape[0] // 2:]
+            ref_chosen_logps = ref_average_log_prob[: input_chunk.shape[0] // 2]
+            ref_rejected_logps = ref_average_log_prob[input_chunk.shape[0] // 2 :]
             return ref_chosen_logps, ref_rejected_logps

     @staticmethod

@@ -242,8 +248,14 @@ def _compute_loss(
         rejected_logps = average_log_prob[len_chosen_chunk:]

         if use_ref_model:
-            ref_chosen_logps, ref_rejected_logps = LigerFusedLinearPreferenceBase.get_ref_logps(
-                input_chunk, ref_weight, target_chunk, ref_bias=ref_bias, ignore_index=ignore_index
+            ref_chosen_logps, ref_rejected_logps = (
+                LigerFusedLinearPreferenceBase.get_ref_logps(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias=ref_bias,
+                    ignore_index=ignore_index,
+                )
             )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

test/chunked_loss/test_dpo_loss.py

Lines changed: 28 additions & 12 deletions

@@ -19,8 +19,12 @@ class HFDPOLoss(HFAlignmentLoss):
     Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py
     """

-    def __init__(self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True):
-        super().__init__(beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model)
+    def __init__(
+        self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True
+    ):
+        super().__init__(
+            beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model
+        )

     def alignment_loss(
         self,

@@ -69,7 +73,9 @@ def __init__(
         ).get_batch_loss_metrics

     def forward(self, x, y):
-        return self.dpo_loss(self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias)
+        return self.dpo_loss(
+            self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
+        )


 class LigerLMHeadDPO(torch.nn.Module):

@@ -90,10 +96,14 @@ def __init__(
         self.ref_lin = torch.nn.Linear(
             in_features=H, out_features=V, bias=ref_bias, dtype=dtype
         )
-        self.dpo_loss = LigerFusedLinearDPOLoss(ignore_index=ignore_index, beta=beta, use_ref_model=True)
+        self.dpo_loss = LigerFusedLinearDPOLoss(
+            ignore_index=ignore_index, beta=beta, use_ref_model=True
+        )

     def forward(self, x, y):
-        return self.dpo_loss(self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias)
+        return self.dpo_loss(
+            self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
+        )


 @pytest.mark.parametrize(

@@ -113,7 +123,9 @@ def forward(self, x, y):
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("ref_bias", [True, False])
 @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
-def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta):
+def test_correctness(
+    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta
+):
     B = 2 * B  # dpo loss requires B to be even

     torch_lm_head_dpo = TorchLMHeadDPO(

@@ -138,17 +150,17 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, igno
     torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn(
         V, H, device="cuda", dtype=dtype
     )
-    torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = torch.randn(
-        V, H, device="cuda", dtype=dtype
+    torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = (
+        torch.randn(V, H, device="cuda", dtype=dtype)
     )

     if bias:
         torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn(
             V, device="cuda", dtype=dtype
         )
     if ref_bias:
-        torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = torch.randn(
-            V, device="cuda", dtype=dtype
+        torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = (
+            torch.randn(V, device="cuda", dtype=dtype)
         )

     _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar

@@ -244,8 +256,12 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
     ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None
     ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None

-    loss1 = LigerFusedLinearDPOFunction.apply(input1, weight1, target, bias1, ref_weight1, ref_bias1)
-    loss2 = liger_fused_linear_dpo(input2, weight2, target, bias2, ref_weight2, ref_bias2)
+    loss1 = LigerFusedLinearDPOFunction.apply(
+        input1, weight1, target, bias1, ref_weight1, ref_bias1
+    )
+    loss2 = liger_fused_linear_dpo(
+        input2, weight2, target, bias2, ref_weight2, ref_bias2
+    )

     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
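The chained assignments in test_correctness above tie the Torch and Liger heads to the same underlying tensors, so both implementations start from identical weights. A CPU-only sketch of that pattern, with made-up sizes:

import torch

V, H = 32, 16  # made-up sizes, for illustration only
lin_a = torch.nn.Linear(H, V, bias=False)
lin_b = torch.nn.Linear(H, V, bias=False)

# Chained assignment binds one tensor object to both .data attributes,
# so the two modules share storage and therefore identical initial weights.
lin_a.weight.data = lin_b.weight.data = torch.randn(V, H)
assert lin_a.weight.data_ptr() == lin_b.weight.data_ptr()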

test/utils.py

Lines changed: 17 additions & 4 deletions

@@ -355,7 +355,13 @@ def revert_liger_kernel_to_phi3(model_config: MiniModelConfig):

 class HFAlignmentLoss:

-    def __init__(self, alpha: float = 1.0, beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False):
+    def __init__(
+        self,
+        alpha: float = 1.0,
+        beta: float = 0.1,
+        ignore_index: int = -100,
+        use_ref_model: bool = False,
+    ):
         self.alpha = alpha
         self.beta = beta
         self.ignore_index = ignore_index

@@ -414,8 +420,13 @@ def get_ref_logps(
         ref_logits = _input @ ref_weight.t()
         if ref_bias is not None:
             ref_logits = ref_logits + ref_bias
-        ref_all_logps = self.get_batch_logps(ref_logits, target, average_log_prob=average_log_prob)
-        return ref_all_logps[:_input.shape[0] // 2], ref_all_logps[_input.shape[0] // 2:]
+        ref_all_logps = self.get_batch_logps(
+            ref_logits, target, average_log_prob=average_log_prob
+        )
+        return (
+            ref_all_logps[: _input.shape[0] // 2],
+            ref_all_logps[_input.shape[0] // 2 :],
+        )

     def concatenated_forward(
         self,

@@ -503,7 +514,9 @@ def get_batch_loss_metrics(
             )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
-        losses = self.alignment_loss(policy_chosen_logps, policy_rejected_logps, **loss_kwargs)
+        losses = self.alignment_loss(
+            policy_chosen_logps, policy_rejected_logps, **loss_kwargs
+        )
         # full loss
         loss = policy_nll_loss * self.alpha - losses.mean()
         return loss
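The last hunk only rewraps the alignment_loss call; the combination below it is unchanged. A small numeric illustration of that final line, with made-up values:

import torch

alpha = 1.0
policy_nll_loss = torch.tensor(2.5)   # made-up NLL value
losses = torch.tensor([0.6, 0.4])     # made-up per-pair alignment losses
loss = policy_nll_loss * alpha - losses.mean()
print(loss)  # tensor(2.)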
