|
19 | 19 | from datasets import load_dataset
|
20 | 20 | from parameterized import parameterized
|
21 | 21 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
22 |
| -from transformers.testing_utils import require_peft |
| 22 | +from transformers.testing_utils import require_liger_kernel, require_peft |
23 | 23 |
|
24 | 24 | from trl import KTOConfig, KTOTrainer
|
25 | 25 | from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
|
@@ -379,6 +379,38 @@ def test_kto_lora_save(self):
|
379 | 379 | except OSError:
|
380 | 380 | self.fail("Loading the saved peft adapter failed")
|
381 | 381 |
|
@require_liger_kernel
def test_kto_trainer_with_liger(self):
    """Train KTO with `use_liger_loss=True` and verify a loss is logged and the weights move."""
    with tempfile.TemporaryDirectory() as output_dir:
        config = KTOConfig(output_dir=output_dir, report_to="none", use_liger_loss=True)

        dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

        trainer = KTOTrainer(
            model=self.model,
            args=config,
            processing_class=self.tokenizer,
            train_dataset=dataset["train"],
        )

        # Snapshot every parameter before training so it can be compared afterwards.
        params_before = {name: tensor.clone() for name, tensor in trainer.model.named_parameters()}

        trainer.train()

        # The last log entry must carry a training loss.
        final_log = trainer.state.log_history[-1]
        self.assertIsNotNone(final_log["train_loss"])

        # Each parameter should differ from its snapshot after training.
        for name, old_value in params_before.items():
            updated_value = trainer.model.get_parameter(name)
            # Parameters that sum to zero (e.g. zero biases) are not checked.
            if old_value.sum() == 0:
                continue
            self.assertFalse(torch.equal(old_value, updated_value))
| 413 | + |
382 | 414 | def test_compute_metrics(self):
|
383 | 415 | model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
384 | 416 | ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
0 commit comments