
Commit 6ab3b9f

shivam15s and Shivam Sahni authored
Monkey patch layer norm in mllama (#302)
## Summary

Monkey patches layer norm in Mllama for conditional generation: when `layer_norm=True`, `nn.LayerNorm` referenced by `modeling_mllama` is replaced with `LigerLayerNorm`, and the vision model's layer norms are additionally patched in place for already-instantiated models.

## Testing Done

Verified that the monkey patching works as intended.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shivam Sahni <[email protected]>
1 parent 24a7efc commit 6ab3b9f
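A minimal usage sketch of the patched entry point, assuming `transformers` with Mllama support and this revision of Liger Kernel are installed; the checkpoint id is illustrative only:

```python
from transformers import MllamaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama

# Patch module-level symbols (e.g. nn.LayerNorm -> LigerLayerNorm) before the
# model is constructed, so newly built modules pick up the Liger kernels.
apply_liger_kernel_to_mllama(layer_norm=True)

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct"  # illustrative checkpoint
)

# Or patch an already-instantiated model in place, including the vision tower's
# layernorm_pre/layernorm_post and per-layer norms added in this commit:
# apply_liger_kernel_to_mllama(layer_norm=True, model=model)
```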

File tree

4 files changed: +66 additions, −5 deletions


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 21 additions & 1 deletion
@@ -121,6 +121,7 @@ def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
@@ -151,12 +152,15 @@ def apply_liger_kernel_to_mllama(
         MllamaForCausalLM,
         MllamaForConditionalGeneration,
         MllamaTextModel,
+        MllamaVisionModel,
     )
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if layer_norm:
+        modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
     if swiglu:
@@ -174,11 +178,14 @@ def apply_liger_kernel_to_mllama(
 
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
+            vision_model: MllamaVisionModel = model.vision_model
             text_model: MllamaTextModel = language_model.model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
+            vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
+            vision_model = None
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -194,6 +201,20 @@ def apply_liger_kernel_to_mllama(
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+            for layer in vision_model.global_transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
 
 def apply_liger_kernel_to_mistral(
     rope: bool = True,
@@ -767,7 +788,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
             for key, value in kwargs.items()
            if key in apply_fn_signature.parameters
         }
-
        logger.info(
            f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
        )
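Since `modeling_mllama` imports `nn` from `torch`, the assignment `modeling_mllama.nn.LayerNorm = LigerLayerNorm` effectively rebinds `torch.nn.LayerNorm` process-wide (which is also why the `test/utils.py` change below reloads `torch.nn` when reverting), and it only affects LayerNorm modules constructed after the patch; models that already exist are therefore walked and patched per-module with `_patch_layer_norm_module`. A self-contained sketch of that class-level pattern, using a toy class rather than the repo's code:

```python
import torch
import torch.nn as nn


class ToyLayerNorm(nn.LayerNorm):
    """Stand-in for LigerLayerNorm in this sketch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as nn.LayerNorm here; a real Liger kernel would fuse it.
        return super().forward(x)


before = nn.LayerNorm(8)      # built with the original class
original = nn.LayerNorm       # keep a handle so the patch can be undone here
nn.LayerNorm = ToyLayerNorm   # the monkey patch (what the hunk above does)
after = nn.LayerNorm(8)       # built with the replacement class

print(type(before).__name__, type(after).__name__)  # LayerNorm ToyLayerNorm
# Already-built instances are unaffected, hence the per-module patching above.

nn.LayerNorm = original       # undo; test/utils.py reverts via importlib.reload
```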

test/convergence/test_mini_models_multimodal.py

Lines changed: 1 addition & 4 deletions
@@ -316,15 +316,12 @@ def run_mini_model_multimodal(
     kwargs = {
         "rms_norm": True,
         "cross_entropy": True,
+        "layer_norm": True,
     }
     model_supports_rope = "qwen2_vl" not in model_name
     if model_supports_rope:
         kwargs["rope"] = True
 
-    model_supports_layer_norm = "qwen2_vl" in model_name
-    if model_supports_layer_norm:
-        kwargs["layer_norm"] = True
-
     if "gemma" in model_name:
         kwargs["geglu"] = True
     else:
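With `layer_norm` now always set in `kwargs`, any path that routes these kwargs through `_apply_liger_kernel_to_instance` relies on its signature-based filtering (last hunk of `monkey_patch.py` above) to drop flags a given model's patch function does not accept. A toy illustration of that filtering, with made-up function and flag names:

```python
import inspect


def apply_toy_kernel(rope: bool = True, rms_norm: bool = True) -> dict:
    """Toy apply function that, unlike Mllama's, has no layer_norm parameter."""
    return {"rope": rope, "rms_norm": rms_norm}


kwargs = {"rms_norm": True, "cross_entropy": True, "layer_norm": True}
sig = inspect.signature(apply_toy_kernel)

# Keep only the kwargs the target function actually accepts.
applicable = {k: v for k, v in kwargs.items() if k in sig.parameters}
print(applicable)  # {'rms_norm': True} -- layer_norm is dropped harmlessly
```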

test/transformers/test_monkey_patch.py

Lines changed: 42 additions & 0 deletions
@@ -302,6 +302,27 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
                 layer.post_attention_layernorm.forward
             ) != inspect.getsource(LigerRMSNorm.forward)
 
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_pre.forward
+        ) != inspect.getsource(LigerLayerNorm.forward)
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_post.forward
+        ) != inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.global_transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) != inspect.getsource(LigerLayerNorm.forward)
+
         # Test applying kernels to the model instance
         _apply_liger_kernel_to_instance(model=dummy_model_instance)
 
@@ -320,6 +341,27 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
                 layer.post_attention_layernorm.forward
             ) == inspect.getsource(LigerRMSNorm.forward)
 
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_pre.forward
+        ) == inspect.getsource(LigerLayerNorm.forward)
+        assert inspect.getsource(
+            dummy_model_instance.vision_model.layernorm_post.forward
+        ) == inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+        for layer in dummy_model_instance.vision_model.global_transformer.layers:
+            assert inspect.getsource(
+                layer.input_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(
+                layer.post_attention_layernorm.forward
+            ) == inspect.getsource(LigerLayerNorm.forward)
+
 
 def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm():
     # Ensure any monkey patching is cleaned up for subsequent tests
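These assertions verify patching by comparing source text: before `_apply_liger_kernel_to_instance` runs, each norm's bound `forward` must differ from `LigerLayerNorm.forward`; afterwards the sources must match. A minimal standalone illustration of that check, with a toy class standing in for the Liger module:

```python
import inspect

import torch.nn as nn


class ToyLigerLayerNorm(nn.Module):
    """Toy stand-in for LigerLayerNorm."""

    def forward(self, x):
        return x  # placeholder for a fused kernel


# Run as a script: inspect.getsource needs the definitions to live in a file.
norm = nn.LayerNorm(8)

# Unpatched module: its forward source differs from the toy "Liger" version.
assert inspect.getsource(norm.forward) != inspect.getsource(ToyLigerLayerNorm.forward)

# Simulate instance-level patching by binding the replacement forward.
norm.forward = ToyLigerLayerNorm.forward.__get__(norm, type(norm))
assert inspect.getsource(norm.forward) == inspect.getsource(ToyLigerLayerNorm.forward)
```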

test/utils.py

Lines changed: 2 additions & 0 deletions
@@ -222,8 +222,10 @@ def revert_liger_kernel_to_mllama():
     Revert all Liger kernel patches applied to MLlama.
     """
 
+    import torch.nn as nn
     from transformers.models.mllama import modeling_mllama
 
+    importlib.reload(nn)
     importlib.reload(modeling_mllama)
     print("Liger kernel patches have been reverted.")
 
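Reloading `torch.nn` is needed here because the Mllama patch rebinds `nn.LayerNorm` on the `torch.nn` module itself; `importlib.reload` re-executes `torch/nn/__init__.py` and restores the original symbol. A standalone sketch of that behavior (illustrative, not the repo's test code):

```python
import importlib

import torch.nn as nn


class ToyLayerNorm(nn.LayerNorm):
    """Stand-in replacement class for this sketch."""


nn.LayerNorm = ToyLayerNorm               # simulate the monkey patch
assert nn.LayerNorm is ToyLayerNorm

importlib.reload(nn)                      # re-executes torch/nn/__init__.py
assert nn.LayerNorm is not ToyLayerNorm   # original class is rebound
print(nn.LayerNorm.__name__)              # LayerNorm
```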
