Commit 22a79a2

upgrade liger to 0.4.0 (#1973)
* upgrade liger to 0.3.1
* update docs and example
* skip duplicate code check
* Update src/axolotl/integrations/liger/args.py (Co-authored-by: NanoCode012 <[email protected]>)
* Update README.md (Co-authored-by: NanoCode012 <[email protected]>)
* add logging
* chore: lint
* add test case
* upgrade liger and transformers
* also upgrade accelerate
* use kwargs to support patch release
* make sure prepared path is empty for test
* use transformers 4.46.1 since 4.46.2 breaks fsdp

Co-authored-by: NanoCode012 <[email protected]>
1 parent b79badc commit 22a79a2

File tree: 11 files changed, +146 / -119 lines

README.md

Lines changed: 2 additions & 1 deletion

@@ -562,7 +562,8 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
+liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
 ```

examples/deepseek-v2/qlora-fsdp-2_5.yaml

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ strict: false
 plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 chat_template: deepseek_v2

examples/llama-3/fft-8b-liger-fsdp.yaml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 strict: false

requirements.txt

Lines changed: 3 additions & 3 deletions

@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.46.1
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
-accelerate==1.0.1
+accelerate==1.1.0
 datasets==3.0.1
 deepspeed==0.15.3
 pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
-liger-kernel==0.3.0
+liger-kernel==0.4.0

 mamba-ssm==1.2.0.post1

src/axolotl/core/trainer_builder.py

Lines changed: 2 additions & 2 deletions

@@ -896,13 +896,13 @@ def store_metrics(
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)

-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        return super()._save_checkpoint(model, trial, **kwargs)


 class AxolotlMambaTrainer(AxolotlTrainer):
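Note on the change above: switching from an explicit `metrics=None` parameter to `**kwargs` is what the commit message calls "use kwargs to support patch release". The `_save_checkpoint` signature differs between transformers patch versions, so forwarding whatever the caller supplies keeps the override compatible either way. A minimal sketch of the pattern, with hypothetical class names standing in for the real trainer classes:

# Sketch only: BaseTrainer stands in for transformers' Trainer, whose
# _save_checkpoint signature may gain or lose parameters between releases.
class BaseTrainer:
    def _save_checkpoint(self, model, trial):
        print(f"base checkpoint for {model!r}, trial={trial!r}")


class PatchedTrainer(BaseTrainer):
    def _save_checkpoint(self, model, trial, **kwargs):
        # local bookkeeping would go here; then defer to the parent,
        # passing through any extra keyword arguments unchanged
        return super()._save_checkpoint(model, trial, **kwargs)


PatchedTrainer()._save_checkpoint("tiny-model", None)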

src/axolotl/integrations/liger/__init__.py

Lines changed: 31 additions & 109 deletions

@@ -18,20 +18,23 @@
 Liger Kernel is the collection of Triton-native kernels for LLM Training.
 It is designed to be performant, correct, and light-weight.
 """
+import inspect
 import logging
 import sys
-from functools import partial

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
-from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

 from axolotl.integrations.base import BasePlugin

+from ...utils.distributed import zero_only
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401

+LOG = logging.getLogger("axolotl.integrations.liger")
+

 class LigerPlugin(BasePlugin):
     """
@@ -42,59 +45,31 @@ def get_input_args(self):
         return "axolotl.integrations.liger.LigerArgs"

     def pre_model_load(self, cfg):
-        if cfg.model_config_type == "llama":
-            from liger_kernel.transformers.model.llama import (
-                lce_forward as llama_lce_forward,
-            )
-            from transformers.models.llama import modeling_llama
-
-            if cfg.liger_rope:
-                modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_llama.LlamaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_llama.LlamaMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
-            elif cfg.liger_fused_linear_cross_entropy:
-                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
-
-        elif cfg.model_config_type == "mistral":
-            from liger_kernel.transformers.model.mistral import (
-                lce_forward as mistral_lce_forward,
-            )
-            from transformers.models.mistral import modeling_mistral
-
-            if cfg.liger_rope:
-                modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_mistral.MistralRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_mistral.MistralMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
-
-        elif cfg.model_config_type == "gemma":
-            from liger_kernel.transformers.model.gemma import (
-                lce_forward as gemma_lce_forward,
-            )
-            from transformers.models.gemma import modeling_gemma
-
-            if cfg.liger_rope:
-                modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma.GemmaRMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
+            liger_fn_sig = inspect.signature(apply_liger_fn)
+            kwargs = {}
+            if "rope" in liger_fn_sig.parameters:
+                kwargs["rope"] = cfg.liger_rope
+            if "cross_entropy" in liger_fn_sig.parameters:
+                kwargs["cross_entropy"] = cfg.liger_cross_entropy
+            if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
+                kwargs[
+                    "fused_linear_cross_entropy"
+                ] = cfg.liger_fused_linear_cross_entropy
+            if "rms_norm" in liger_fn_sig.parameters:
+                kwargs["rms_norm"] = cfg.liger_rms_norm
+            if "layer_norm" in liger_fn_sig.parameters:
+                kwargs["layer_norm"] = cfg.liger_layer_norm
+            if "geglu" in liger_fn_sig.parameters:
+                kwargs["geglu"] = cfg.liger_glu_activation
+            elif "swiglu" in liger_fn_sig.parameters:
+                kwargs["swiglu"] = cfg.liger_glu_activation
+            with zero_only():
+                LOG.info(
+                    f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
                 )
-            if cfg.liger_swiglu:
-                modeling_gemma.GemmaMLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
-
+            apply_liger_fn(**kwargs)
         elif cfg.model_config_type == "jamba":
             from transformers.models.jamba import modeling_jamba

@@ -104,30 +79,12 @@ def pre_model_load(self, cfg):
                 modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
             if cfg.liger_rms_norm:
                 modeling_jamba.JambaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_jamba.JambaMLP = LigerSwiGLUMLP
             if cfg.liger_cross_entropy:
                 modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
-
-        elif cfg.model_config_type == "qwen2":
-            from liger_kernel.transformers.model.qwen2 import (
-                lce_forward as qwen2_lce_forward,
-            )
-            from transformers.models.qwen2 import modeling_qwen2
-
-            if cfg.liger_rope:
-                modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
-
         elif cfg.model_config_type == "deepseek_v2":
             from accelerate import init_empty_weights
             from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ def pre_model_load(self, cfg):
                 logging.warning("Fused liger_rope is not supported for DeepseekV2.")
             if cfg.liger_rms_norm:
                 modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
             if cfg.liger_cross_entropy:
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
-
-        elif cfg.model_config_type == "gemma2":
-            from transformers.models.gemma2 import modeling_gemma2
-
-            if cfg.liger_rope:
-                modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma2.Gemma2RMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                logging.warning(
-                    "Fused linear cross entropy is not supported for Gemma 2."
-                )
-
-        elif cfg.model_config_type == "phi3":
-            from liger_kernel.transformers.model.phi3 import (
-                lce_forward as phi3_lce_forward,
-            )
-            from transformers.models.phi3 import modeling_phi3
-
-            if cfg.liger_rope:
-                modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_phi3.Phi3RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_phi3.Phi3MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
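Note on the change above: the plugin now defers to liger-kernel's own `MODEL_TYPE_TO_APPLY_LIGER_FN` registry and inspects each patch function's signature so it only passes keyword arguments that function actually accepts. A stripped-down sketch of that filtering idea; the helper and example function below are made up for illustration:

import inspect


def call_with_supported_kwargs(patch_fn, **candidate_kwargs):
    # Forward only the kwargs the patch function declares, so the caller
    # keeps working when a liger-kernel release adds or renames options.
    supported = inspect.signature(patch_fn).parameters
    return patch_fn(**{k: v for k, v in candidate_kwargs.items() if k in supported})


def apply_liger_kernel_to_example(rope=False, rms_norm=False, swiglu=False):
    # Stand-in for one entry of MODEL_TYPE_TO_APPLY_LIGER_FN.
    print(f"patched example model: rope={rope}, rms_norm={rms_norm}, swiglu={swiglu}")


call_with_supported_kwargs(
    apply_liger_kernel_to_example,
    rope=True,
    rms_norm=True,
    swiglu=True,
    layer_norm=True,  # silently dropped: not in this function's signature
)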

src/axolotl/integrations/liger/args.py

Lines changed: 22 additions & 1 deletion

@@ -15,9 +15,12 @@
 """
 Module for handling LIGER input arguments.
 """
+import logging
 from typing import Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
+
+LOG = logging.getLogger("axolotl.integrations.liger.args")


 class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):

     liger_rope: Optional[bool] = None
     liger_rms_norm: Optional[bool] = None
+    liger_layer_norm: Optional[bool] = None
     liger_swiglu: Optional[bool] = None
+    liger_glu_activation: Optional[bool] = None
     liger_cross_entropy: Optional[bool] = None
     liger_fused_linear_cross_entropy: Optional[bool] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deprecated_swiglu(cls, data):
+        if data.get("liger_swiglu") is not None:
+            if data.get("liger_glu_activation") is not None:
+                raise ValueError(
+                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
+                )
+
+            LOG.warning(
+                "The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
+                "Please use 'liger_glu_activation' instead."
+            )
+            data["liger_glu_activation"] = data.pop("liger_swiglu")
+        return data
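Note on the change above: configs that still set the deprecated `liger_swiglu` key are accepted with a warning and remapped onto `liger_glu_activation`, while setting both keys raises. A rough usage sketch, assuming `LigerArgs` is constructed directly; it is not part of the diff itself:

# Sketch of the deprecation path added above.
from axolotl.integrations.liger.args import LigerArgs

args = LigerArgs(liger_swiglu=True)  # logs the deprecation warning
assert args.liger_swiglu is None
assert args.liger_glu_activation is True

try:
    LigerArgs(liger_swiglu=True, liger_glu_activation=True)
except ValueError as err:
    print(err)  # both keys set at once is rejected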

tests/e2e/integrations/liger.py

Lines changed: 0 additions & 1 deletion

@@ -1,7 +1,6 @@
 """
 Simple end-to-end test for Liger integration
 """
-
 import unittest
 from pathlib import Path

tests/integrations/__init__.py

Whitespace-only changes.

tests/integrations/liger.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+"""
+config validation tests for swiglu args
+"""
+# pylint: disable=duplicate-code
+import logging
+from typing import Optional
+
+import pytest
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="minimal_base_cfg")
+def fixture_cfg():
+    return DictDefault(
+        {
+            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+            "learning_rate": 0.000001,
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                }
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+        }
+    )
+
+
+class BaseValidation:
+    """
+    Base validation module to setup the log capture
+    """
+
+    _caplog: Optional[pytest.LogCaptureFixture] = None
+
+    @pytest.fixture(autouse=True)
+    def inject_fixtures(self, caplog):
+        self._caplog = caplog
+
+
+# pylint: disable=too-many-public-methods
+class TestValidation(BaseValidation):
+    """
+    Test the validation module for liger
+    """
+
+    def test_deprecated_swiglu(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "liger_swiglu": False,
+            }
+            | minimal_cfg
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            updated_cfg = validate_config(test_cfg)
+            assert (
+                "The 'liger_swiglu' argument is deprecated"
+                in self._caplog.records[0].message
+            )
+            assert updated_cfg.liger_swiglu is None
+            assert updated_cfg.liger_glu_activations is False
+
+    def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "liger_swiglu": False,
+                "liger_glu_activations": True,
+            }
+            | minimal_cfg
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
+        ):
+            validate_config(test_cfg)
