Skip to content

Commit 30b83ae

Browse files
vaibhavjindalkashifqgallouedec
authored
[Liger] Liger KTO support (huggingface#2812)
Co-authored-by: Kashif Rasul <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 3684b8d commit 30b83ae

File tree

4 files changed

+310
-82
lines changed

4 files changed

+310
-82
lines changed

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@
8383
"diffusers": ["diffusers>=0.18.0"],
8484
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
8585
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
86-
# can be set to >=0.5.3 when https://github.com/linkedin/Liger-Kernel/issues/586 is fixed
87-
"liger": ["liger-kernel==0.5.3; sys_platform != 'win32'"],
86+
"liger": ["liger-kernel>=0.5.5; sys_platform == 'Linux'"],
8887
"mergekit": ["mergekit>=0.0.5.1"],
8988
"peft": ["peft>=0.8.0"],
9089
"quantization": ["bitsandbytes"],

tests/test_kto_trainer.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datasets import load_dataset
2020
from parameterized import parameterized
2121
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
22-
from transformers.testing_utils import require_peft
22+
from transformers.testing_utils import require_liger_kernel, require_peft
2323

2424
from trl import KTOConfig, KTOTrainer
2525
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
@@ -379,6 +379,38 @@ def test_kto_lora_save(self):
379379
except OSError:
380380
self.fail("Loading the saved peft adapter failed")
381381

382+
@require_liger_kernel
383+
def test_kto_trainer_with_liger(self):
384+
"""Test KTO trainer with Liger loss enabled."""
385+
with tempfile.TemporaryDirectory() as tmp_dir:
386+
training_args = KTOConfig(
387+
output_dir=tmp_dir,
388+
report_to="none",
389+
use_liger_loss=True, # Enable Liger loss
390+
)
391+
392+
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
393+
394+
trainer = KTOTrainer(
395+
model=self.model,
396+
args=training_args,
397+
processing_class=self.tokenizer,
398+
train_dataset=dummy_dataset["train"],
399+
)
400+
401+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
402+
403+
trainer.train()
404+
405+
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
406+
407+
# check the params have changed
408+
for n, param in previous_trainable_params.items():
409+
new_param = trainer.model.get_parameter(n)
410+
# check the params have changed - ignore 0 biases
411+
if param.sum() != 0:
412+
self.assertFalse(torch.equal(param, new_param))
413+
382414
def test_compute_metrics(self):
383415
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
384416
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

trl/trainer/kto_config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ class KTOConfig(TrainingArguments):
7878
Number of processes to use for processing the dataset.
7979
disable_dropout (`bool`, *optional*, defaults to `True`):
8080
Whether to disable dropout in the model and reference model.
81+
use_liger_loss (`bool`, *optional*, defaults to `False`):
82+
Whether to use Liger loss. It requires liger-kernel to be installed.
83+
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
84+
Name of the attribute in the model that contains the base model. This is used to get the base model from
85+
the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
8186
"""
8287

8388
learning_rate: float = field(
@@ -193,3 +198,15 @@ class KTOConfig(TrainingArguments):
193198
default=None,
194199
metadata={"help": "Number of processes to use for processing the dataset."},
195200
)
201+
use_liger_loss: bool = field(
202+
default=False,
203+
metadata={"help": "Whether to use Liger loss. It requires liger-kernel to be installed."},
204+
)
205+
base_model_attribute_name: str = field(
206+
default="model",
207+
metadata={
208+
"help": "Name of the attribute in the model that contains the base model. This is used to get the base "
209+
"model from the model when the model does not have a `get_decoder` method in the case when "
210+
"`use_liger_loss` is `True`."
211+
},
212+
)

0 commit comments

Comments
 (0)