18
18
Liger Kernel is the collection of Triton-native kernels for LLM Training.
19
19
It is designed to be performant, correct, and light-weight.
20
20
"""
21
+ import inspect
21
22
import logging
22
23
import sys
23
- from functools import partial
24
24
25
25
from liger_kernel .transformers .cross_entropy import LigerCrossEntropyLoss
26
- from liger_kernel .transformers .geglu import LigerGEGLUMLP
26
+ from liger_kernel .transformers .monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
27
27
from liger_kernel .transformers .rms_norm import LigerRMSNorm
28
28
from liger_kernel .transformers .rope import liger_rotary_pos_emb
29
29
from liger_kernel .transformers .swiglu import LigerSwiGLUMLP
30
30
31
31
from axolotl .integrations .base import BasePlugin
32
32
33
+ from ...utils .distributed import zero_only
33
34
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
34
35
36
+ LOG = logging .getLogger ("axolotl.integrations.liger" )
37
+
35
38
36
39
class LigerPlugin (BasePlugin ):
37
40
"""
@@ -42,59 +45,31 @@ def get_input_args(self):
42
45
return "axolotl.integrations.liger.LigerArgs"
43
46
44
47
def pre_model_load (self , cfg ):
45
- if cfg .model_config_type == "llama" :
46
- from liger_kernel .transformers .model .llama import (
47
- lce_forward as llama_lce_forward ,
48
- )
49
- from transformers .models .llama import modeling_llama
50
-
51
- if cfg .liger_rope :
52
- modeling_llama .apply_rotary_pos_emb = liger_rotary_pos_emb
53
- if cfg .liger_rms_norm :
54
- modeling_llama .LlamaRMSNorm = LigerRMSNorm
55
- if cfg .liger_swiglu :
56
- modeling_llama .LlamaMLP = LigerSwiGLUMLP
57
- if cfg .liger_cross_entropy :
58
- modeling_llama .CrossEntropyLoss = LigerCrossEntropyLoss
59
- elif cfg .liger_fused_linear_cross_entropy :
60
- modeling_llama .LlamaForCausalLM .forward = llama_lce_forward
61
-
62
- elif cfg .model_config_type == "mistral" :
63
- from liger_kernel .transformers .model .mistral import (
64
- lce_forward as mistral_lce_forward ,
65
- )
66
- from transformers .models .mistral import modeling_mistral
67
-
68
- if cfg .liger_rope :
69
- modeling_mistral .apply_rotary_pos_emb = liger_rotary_pos_emb
70
- if cfg .liger_rms_norm :
71
- modeling_mistral .MistralRMSNorm = LigerRMSNorm
72
- if cfg .liger_swiglu :
73
- modeling_mistral .MistralMLP = LigerSwiGLUMLP
74
- if cfg .liger_cross_entropy :
75
- modeling_mistral .CrossEntropyLoss = LigerCrossEntropyLoss
76
- if cfg .liger_fused_linear_cross_entropy :
77
- modeling_mistral .MistralForCausalLM .forward = mistral_lce_forward
78
-
79
- elif cfg .model_config_type == "gemma" :
80
- from liger_kernel .transformers .model .gemma import (
81
- lce_forward as gemma_lce_forward ,
82
- )
83
- from transformers .models .gemma import modeling_gemma
84
-
85
- if cfg .liger_rope :
86
- modeling_gemma .apply_rotary_pos_emb = liger_rotary_pos_emb
87
- if cfg .liger_rms_norm :
88
- modeling_gemma .GemmaRMSNorm = partial (
89
- LigerRMSNorm , offset = 1.0 , init_fn = "zeros" , casting_mode = "gemma"
48
+ if cfg .model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN :
49
+ apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN [cfg .model_config_type ]
50
+ liger_fn_sig = inspect .signature (apply_liger_fn )
51
+ kwargs = {}
52
+ if "rope" in liger_fn_sig .parameters :
53
+ kwargs ["rope" ] = cfg .liger_rope
54
+ if "cross_entropy" in liger_fn_sig .parameters :
55
+ kwargs ["cross_entropy" ] = cfg .liger_cross_entropy
56
+ if "fused_linear_cross_entropy" in liger_fn_sig .parameters :
57
+ kwargs [
58
+ "fused_linear_cross_entropy"
59
+ ] = cfg .liger_fused_linear_cross_entropy
60
+ if "rms_norm" in liger_fn_sig .parameters :
61
+ kwargs ["rms_norm" ] = cfg .liger_rms_norm
62
+ if "layer_norm" in liger_fn_sig .parameters :
63
+ kwargs ["layer_norm" ] = cfg .liger_layer_norm
64
+ if "geglu" in liger_fn_sig .parameters :
65
+ kwargs ["geglu" ] = cfg .liger_glu_activation
66
+ elif "swiglu" in liger_fn_sig .parameters :
67
+ kwargs ["swiglu" ] = cfg .liger_glu_activation
68
+ with zero_only ():
69
+ LOG .info (
70
+ f"Applying LIGER to { cfg .model_config_type } with kwargs: { kwargs } "
90
71
)
91
- if cfg .liger_swiglu :
92
- modeling_gemma .GemmaMLP = LigerGEGLUMLP
93
- if cfg .liger_cross_entropy :
94
- modeling_gemma .CrossEntropyLoss = LigerCrossEntropyLoss
95
- if cfg .liger_fused_linear_cross_entropy :
96
- modeling_gemma .GemmaForCausalLM .forward = gemma_lce_forward
97
-
72
+ apply_liger_fn (** kwargs )
98
73
elif cfg .model_config_type == "jamba" :
99
74
from transformers .models .jamba import modeling_jamba
100
75
@@ -104,30 +79,12 @@ def pre_model_load(self, cfg):
104
79
modeling_jamba .apply_rotary_pos_emb = liger_rotary_pos_emb
105
80
if cfg .liger_rms_norm :
106
81
modeling_jamba .JambaRMSNorm = LigerRMSNorm
107
- if cfg .liger_swiglu :
82
+ if cfg .liger_glu_activation :
108
83
modeling_jamba .JambaMLP = LigerSwiGLUMLP
109
84
if cfg .liger_cross_entropy :
110
85
modeling_jamba .CrossEntropyLoss = LigerCrossEntropyLoss
111
86
if cfg .liger_fused_linear_cross_entropy :
112
87
modeling_jamba .JambaForCausalLM .forward = jamba_lce_forward
113
-
114
- elif cfg .model_config_type == "qwen2" :
115
- from liger_kernel .transformers .model .qwen2 import (
116
- lce_forward as qwen2_lce_forward ,
117
- )
118
- from transformers .models .qwen2 import modeling_qwen2
119
-
120
- if cfg .liger_rope :
121
- modeling_qwen2 .apply_rotary_pos_emb = liger_rotary_pos_emb
122
- if cfg .liger_rms_norm :
123
- modeling_qwen2 .Qwen2RMSNorm = LigerRMSNorm
124
- if cfg .liger_swiglu :
125
- modeling_qwen2 .Qwen2MLP = LigerSwiGLUMLP
126
- if cfg .liger_cross_entropy :
127
- modeling_qwen2 .CrossEntropyLoss = LigerCrossEntropyLoss
128
- if cfg .liger_fused_linear_cross_entropy :
129
- modeling_qwen2 .Qwen2ForCausalLM .forward = qwen2_lce_forward
130
-
131
88
elif cfg .model_config_type == "deepseek_v2" :
132
89
from accelerate import init_empty_weights
133
90
from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ def pre_model_load(self, cfg):
146
103
logging .warning ("Fused liger_rope is not supported for DeepseekV2." )
147
104
if cfg .liger_rms_norm :
148
105
modeling_mod .DeepseekV2RMSNorm = LigerRMSNorm
149
- if cfg .liger_swiglu :
106
+ if cfg .liger_glu_activation :
150
107
modeling_mod .DeepseekV2MLP .forward = LigerSwiGLUMLP .forward
151
108
if cfg .liger_cross_entropy :
152
109
modeling_mod .CrossEntropyLoss = LigerCrossEntropyLoss
153
110
if cfg .liger_fused_linear_cross_entropy :
154
111
modeling_mod .DeepseekV2ForCausalLM .forward = deepseekv2_lce_forward
155
-
156
- elif cfg .model_config_type == "gemma2" :
157
- from transformers .models .gemma2 import modeling_gemma2
158
-
159
- if cfg .liger_rope :
160
- modeling_gemma2 .apply_rotary_pos_emb = liger_rotary_pos_emb
161
- if cfg .liger_rms_norm :
162
- modeling_gemma2 .Gemma2RMSNorm = partial (
163
- LigerRMSNorm , offset = 1.0 , init_fn = "zeros" , casting_mode = "gemma"
164
- )
165
- if cfg .liger_swiglu :
166
- modeling_gemma2 .Gemma2MLP = LigerGEGLUMLP
167
- if cfg .liger_cross_entropy :
168
- modeling_gemma2 .CrossEntropyLoss = LigerCrossEntropyLoss
169
- if cfg .liger_fused_linear_cross_entropy :
170
- logging .warning (
171
- "Fused linear cross entropy is not supported for Gemma 2."
172
- )
173
-
174
- elif cfg .model_config_type == "phi3" :
175
- from liger_kernel .transformers .model .phi3 import (
176
- lce_forward as phi3_lce_forward ,
177
- )
178
- from transformers .models .phi3 import modeling_phi3
179
-
180
- if cfg .liger_rope :
181
- modeling_phi3 .apply_rotary_pos_emb = liger_rotary_pos_emb
182
- if cfg .liger_rms_norm :
183
- modeling_phi3 .Phi3RMSNorm = LigerRMSNorm
184
- if cfg .liger_swiglu :
185
- modeling_phi3 .Phi3MLP = LigerSwiGLUMLP
186
- if cfg .liger_cross_entropy :
187
- modeling_phi3 .CrossEntropyLoss = LigerCrossEntropyLoss
188
- if cfg .liger_fused_linear_cross_entropy :
189
- modeling_phi3 .Phi3ForCausalLM .forward = phi3_lce_forward
0 commit comments