Skip to content

Commit 6d9fc11

Browse files
authored
[SFT] fix check for AutoLigerKernelForCausalLM (#2874)
* fix check for AutoLigerKernelForCausalLM * fix case where AutoLigerKernelForCausalLM is not defined * update min liger version * formatting * fix win CI
1 parent ffcb9f4 commit 6d9fc11

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
"diffusers": ["diffusers>=0.18.0"],
8686
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
8787
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
88-
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
88+
"liger": ["liger-kernel>=0.5.3; sys_platform != 'win32'"],
8989
"mergekit": ["mergekit>=0.0.5.1"],
9090
"peft": ["peft>=0.8.0"],
9191
"quantization": ["bitsandbytes"],

trl/trainer/sft_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060

6161
if is_liger_kernel_available():
6262
from liger_kernel.transformers import AutoLigerKernelForCausalLM
63+
else:
64+
AutoLigerKernelForCausalLM = None
6365

6466
if is_wandb_available():
6567
import wandb
@@ -440,7 +442,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
440442
)
441443

442444
# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
443-
if "labels" in inputs and not self.args.use_liger:
445+
use_liger = self.args.use_liger or (
446+
AutoLigerKernelForCausalLM is not None and isinstance(model, AutoLigerKernelForCausalLM)
447+
)
448+
if "labels" in inputs and not use_liger:
444449
shift_logits = outputs.logits[..., :-1, :].contiguous()
445450
shift_labels = inputs["labels"][..., 1:].contiguous()
446451

0 commit comments

Comments
 (0)