Skip to content

Commit ad9d9c9

Browse files
sergiopaniego, kashif, and albertvillanova
authored
Remove liger loss in favor of liger kernel (huggingface#4364)
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Albert Villanova del Moral <[email protected]>
1 parent 095544e commit ad9d9c9

File tree

12 files changed

+118
-49
lines changed

12 files changed

+118
-49
lines changed

docs/source/reducing_memory_usage.md

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,30 +102,48 @@ For more information, see [Liger Kernel Integration](liger_kernel_integration).
102102
To use Liger for reducing peak memory usage, use the following code snippet:
103103

104104
<hfoptions id="liger">
105+
<hfoption id="SFT">
106+
107+
```python
108+
from trl import SFTConfig
109+
110+
training_args = SFTConfig(..., use_liger_kernel=True)
111+
```
112+
113+
</hfoption>
105114
<hfoption id="DPO">
106-
115+
107116
```python
108117
from trl import DPOConfig
109118

110-
training_args = DPOConfig(..., use_liger_loss=True)
119+
training_args = DPOConfig(..., use_liger_kernel=True)
111120
```
112121

113122
</hfoption>
114123
<hfoption id="GRPO">
115-
124+
116125
```python
117126
from trl import GRPOConfig
118127

119-
training_args = GRPOConfig(..., use_liger_loss=True)
128+
training_args = GRPOConfig(..., use_liger_kernel=True)
120129
```
121130

122131
</hfoption>
123132
<hfoption id="KTO">
124-
133+
125134
```python
126135
from trl import KTOConfig
127136

128-
training_args = KTOConfig(..., use_liger_loss=True)
137+
training_args = KTOConfig(..., use_liger_kernel=True)
138+
```
139+
140+
</hfoption>
141+
<hfoption id="GKD">
142+
143+
```python
144+
from trl import GKDConfig
145+
146+
training_args = GKDConfig(..., use_liger_kernel=True)
129147
```
130148

131149
</hfoption>

tests/slow/test_grpo_slow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def teardown_method(self):
6767

6868
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
6969
@require_liger_kernel
70-
def test_training_with_liger_grpo_loss(self, model_name):
70+
def test_training_with_liger_grpo_kernel(self, model_name):
7171
training_args = GRPOConfig(
7272
output_dir=self.tmp_dir,
7373
per_device_train_batch_size=3,
7474
num_generations=3,
75-
use_liger_loss=True,
75+
use_liger_kernel=True,
7676
max_completion_length=self.max_length,
7777
report_to="none",
7878
logging_strategy="no",
@@ -108,14 +108,14 @@ def test_training_with_liger_grpo_loss(self, model_name):
108108
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
109109
@require_liger_kernel
110110
@require_peft
111-
def test_training_with_liger_grpo_loss_and_peft(self, model_name):
111+
def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
112112
from peft import LoraConfig, TaskType
113113

114114
training_args = GRPOConfig(
115115
output_dir=self.tmp_dir,
116116
per_device_train_batch_size=3,
117117
num_generations=3,
118-
use_liger_loss=True,
118+
use_liger_kernel=True,
119119
max_completion_length=self.max_length,
120120
report_to="none",
121121
logging_strategy="no",

tests/test_dpo_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def test_train_encoder_decoder_liger(self):
250250
per_device_train_batch_size=2,
251251
learning_rate=9e-1,
252252
report_to="none",
253-
use_liger_loss=True,
253+
use_liger_kernel=True,
254254
)
255255
trainer = DPOTrainer(
256256
model=model,
@@ -1330,7 +1330,7 @@ def test_dpo_trainer_with_liger(self, beta, loss_type):
13301330
learning_rate=9e-1,
13311331
eval_strategy="steps",
13321332
beta=beta,
1333-
use_liger_loss=True, # Enable Liger loss
1333+
use_liger_kernel=True, # Enable Liger kernel
13341334
loss_type=loss_type,
13351335
report_to="none",
13361336
)

tests/test_gkd_trainer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616

17+
import pytest
1718
import torch
1819
import torch.nn.functional as F
1920
from datasets import load_dataset
@@ -29,9 +30,10 @@ class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
2930
@classmethod
3031
def setup_class(cls):
3132
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
33+
cls.device = "cuda" if torch.cuda.is_available() else "cpu"
3234
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
3335
cls.tokenizer.pad_token = cls.tokenizer.eos_token
34-
cls.model = AutoModelForCausalLM.from_pretrained(model_id)
36+
cls.model = AutoModelForCausalLM.from_pretrained(model_id).to(cls.device)
3537
cls.generation_config = GenerationConfig(
3638
max_new_tokens=20,
3739
num_return_sequences=1,
@@ -44,8 +46,8 @@ def test_generate_on_policy_outputs_deterministic(self):
4446
tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)
4547

4648
inputs = {
47-
"prompts": tokenized_prompts["input_ids"],
48-
"prompt_attention_mask": tokenized_prompts["attention_mask"],
49+
"prompts": tokenized_prompts["input_ids"].to(self.device),
50+
"prompt_attention_mask": tokenized_prompts["attention_mask"].to(self.device),
4951
}
5052

5153
# Set temperature to 0 for deterministic output
@@ -91,8 +93,8 @@ def test_generate_on_policy_outputs(self):
9193
tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)
9294

9395
inputs = {
94-
"prompts": tokenized_prompts["input_ids"],
95-
"attention_mask": tokenized_prompts["attention_mask"],
96+
"prompts": tokenized_prompts["input_ids"].to(self.device),
97+
"attention_mask": tokenized_prompts["attention_mask"].to(self.device),
9698
}
9799

98100
outputs = GKDTrainer.generate_on_policy_outputs(
@@ -238,6 +240,7 @@ def test_gkd_trainer(self):
238240
assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2")
239241

240242
@require_liger_kernel
243+
@pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.")
241244
def test_gkd_trainer_with_liger(self):
242245
training_args = GKDConfig(
243246
output_dir=self.tmp_dir,

tests/test_grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1460,7 +1460,7 @@ def reward_func(completions, **kwargs):
14601460
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
14611461
num_generations=3, # reduce the number of generations to reduce memory usage
14621462
max_completion_length=8, # reduce the completion length to reduce memory usage
1463-
use_liger_loss=True, # enable Liger loss
1463+
use_liger_kernel=True, # enable Liger kernel
14641464
loss_type="bnpo", # default dapo is not supported yet
14651465
report_to="none",
14661466
)

tests/test_kto_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,11 @@ def test_kto_lora_save(self):
361361

362362
@require_liger_kernel
363363
def test_kto_trainer_with_liger(self):
364-
"""Test KTO trainer with Liger loss enabled."""
364+
"""Test KTO trainer with Liger kernel enabled."""
365365
training_args = KTOConfig(
366366
output_dir=self.tmp_dir,
367367
report_to="none",
368-
use_liger_loss=True, # Enable Liger loss
368+
use_liger_kernel=True, # Enable Liger kernel
369369
)
370370

371371
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

trl/trainer/dpo_config.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
1516
from dataclasses import dataclass, field
1617
from enum import Enum
1718
from typing import Any, Callable, Optional, Union
@@ -156,11 +157,17 @@ class DPOConfig(TrainingArguments):
156157
[MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify
157158
corresponding weights for each loss type.
158159
159-
use_liger_loss (`bool`, *optional*, defaults to `False`):
160+
use_liger_loss (`bool`, *optional*, defaults to `None`):
160161
Whether to use Liger loss.
162+
163+
<Deprecated version="0.25.0">
164+
165+
Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead.
166+
167+
</Deprecated>
161168
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
162169
Name of the attribute in the model that contains the base model. This is used to get the base model from
163-
the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
170+
the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is `True`.
164171
beta (`float`, *optional*, defaults to `0.1`):
165172
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
166173
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
@@ -378,15 +385,15 @@ class DPOConfig(TrainingArguments):
378385
},
379386
)
380387
use_liger_loss: bool = field(
381-
default=False,
388+
default=None,
382389
metadata={"help": "Whether to use Liger loss."},
383390
)
384391
base_model_attribute_name: str = field(
385392
default="model",
386393
metadata={
387394
"help": "Name of the attribute in the model that contains the base model. This is used to get the base "
388395
"model from the model when the model does not have a `get_decoder` method in the case when "
389-
"`use_liger_loss` is `True`."
396+
"`use_liger_kernel` is `True`."
390397
},
391398
)
392399
beta: float = field(
@@ -510,4 +517,13 @@ def __post_init__(self):
510517
f"Length of loss_weights list ({self.loss_weights}) must match number of loss types "
511518
f"({loss_types})."
512519
)
520+
521+
if self.use_liger_loss:
522+
warnings.warn(
523+
"The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
524+
"`use_liger_kernel` instead.",
525+
FutureWarning,
526+
stacklevel=2,
527+
)
528+
self.use_liger_kernel = self.use_liger_loss
513529
super().__post_init__()

trl/trainer/dpo_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,15 +378,15 @@ def __init__(
378378
disable_dropout_in_model(self.ref_model)
379379

380380
# Liger kernel
381-
if args.use_liger_loss:
381+
if args.use_liger_kernel:
382382
if not is_liger_kernel_available():
383383
raise ImportError(
384-
"You set `use_liger_loss=True` but the liger kernel is not available. "
384+
"You set `use_liger_kernel=True` but the liger kernel is not available. "
385385
"Please install liger-kernel first: `pip install liger-kernel`"
386386
)
387387
if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]:
388388
raise ValueError(
389-
"You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
389+
"You set `use_liger_kernel=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
390390
"Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
391391
)
392392
self.dpo_loss_fn = LigerFusedLinearDPOLoss(
@@ -1730,7 +1730,7 @@ def get_batch_loss_metrics(
17301730
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
17311731
metrics = {}
17321732

1733-
if self.args.use_liger_loss:
1733+
if self.args.use_liger_kernel:
17341734
model_output = self._compute_loss_liger(model, batch)
17351735
losses = model_output["loss"]
17361736
chosen_rewards = model_output["chosen_rewards"]

trl/trainer/grpo_config.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
1516
from dataclasses import dataclass, field
1617
from typing import Optional, Union
1718

@@ -220,8 +221,14 @@ class GRPOConfig(TrainingArguments):
220221
position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token;
221222
`1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with
222223
`mask_truncated_completions=True`, only tokens from non-truncated completions are considered.
223-
use_liger_loss (`bool`, *optional*, defaults to `False`):
224-
Whether to use the Liger GRPO loss.
224+
use_liger_loss (`bool`, *optional*, defaults to `None`):
225+
Whether to use Liger loss.
226+
227+
<Deprecated version="0.25.0">
228+
229+
Parameter `use_liger_loss` is deprecated and will be removed in version 0.28.0. Use `use_liger_kernel` instead.
230+
231+
</Deprecated>
225232
vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`):
226233
Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed
227234
logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL
@@ -605,7 +612,7 @@ class GRPOConfig(TrainingArguments):
605612
},
606613
)
607614
use_liger_loss: bool = field(
608-
default=False,
615+
default=None,
609616
metadata={"help": "Whether to use the Liger GRPO loss."},
610617
)
611618
vllm_importance_sampling_correction: bool = field(
@@ -697,5 +704,14 @@ def __post_init__(self):
697704
f"{self.num_generations}, which is less than the minimum required."
698705
)
699706

700-
if self.delta is not None and self.use_liger_loss:
701-
raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
707+
if self.use_liger_loss:
708+
warnings.warn(
709+
"The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
710+
"`use_liger_kernel` instead.",
711+
FutureWarning,
712+
stacklevel=2,
713+
)
714+
self.use_liger_kernel = self.use_liger_loss
715+
716+
if self.delta is not None and self.use_liger_kernel:
717+
raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")

trl/trainer/grpo_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -390,17 +390,17 @@ def __init__(
390390
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
391391
self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction
392392
self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
393-
self.use_liger_loss = args.use_liger_loss
393+
self.use_liger_kernel = args.use_liger_kernel
394394
self.loss_type = args.loss_type
395395
self.scale_rewards = args.scale_rewards
396396
self.importance_sampling_level = args.importance_sampling_level
397397
self.mask_truncated_completions = args.mask_truncated_completions
398398
self.top_entropy_quantile = args.top_entropy_quantile
399-
if self.use_liger_loss and self.top_entropy_quantile < 1.0:
399+
if self.use_liger_kernel and self.top_entropy_quantile < 1.0:
400400
raise NotImplementedError(
401401
"Liger Kernels don't currently support masking token positions based on entropy."
402402
)
403-
if self.use_liger_loss and not self.importance_sampling_level == "token":
403+
if self.use_liger_kernel and not self.importance_sampling_level == "token":
404404
raise NotImplementedError(
405405
"Liger Kernels currently only support token-level importance sampling. Please set"
406406
"`importance_sampling_level` to 'token'."
@@ -478,10 +478,10 @@ def __init__(
478478
disable_dropout_in_model(self.ref_model)
479479

480480
# Liger loss
481-
if self.use_liger_loss:
481+
if self.use_liger_kernel:
482482
if not is_liger_kernel_available():
483483
raise ImportError(
484-
"Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
484+
"Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`."
485485
)
486486
# redirect the model.module forward to the model forward to ensure pre-forward hooks are called
487487
self._forward_redirection = _ForwardRedirection()
@@ -1720,7 +1720,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
17201720
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
17211721
if return_outputs:
17221722
raise ValueError("The GRPOTrainer does not support returning outputs")
1723-
if self.use_liger_loss:
1723+
if self.use_liger_kernel:
17241724
# Compute the loss using the liger grpo loss
17251725
unwrapped_model = self.accelerator.unwrap_model(model)
17261726
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)

0 commit comments

Comments
 (0)