Skip to content

Commit a60d315

Browse files
committed
fix(transformers): update layer normalization patching in apply_liger_kernel_to_glm4v function
1 parent e7e61e6 commit a60d315

File tree

2 files changed

+8
-66
lines changed

2 files changed

+8
-66
lines changed

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1910,14 +1910,14 @@ def apply_liger_kernel_to_glm4v(
19101910
if vision_model is not None:
19111911
for vision_block in vision_model.blocks:
19121912
if rms_norm:
1913-
_patch_layer_norm_module(vision_block.norm1)
1914-
_patch_layer_norm_module(vision_block.norm2)
1913+
_patch_rms_norm_module(vision_block.norm1)
1914+
_patch_rms_norm_module(vision_block.norm2)
19151915
if swiglu:
19161916
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
19171917

19181918
if text_model is not None:
19191919
if rms_norm:
1920-
_patch_layer_norm_module(text_model.norm)
1920+
_patch_rms_norm_module(text_model.norm)
19211921
for decoder_layer in text_model.layers:
19221922
if swiglu:
19231923
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
@@ -1935,7 +1935,7 @@ def apply_liger_kernel_to_glm4v(
19351935
"gemma3_text": apply_liger_kernel_to_gemma3_text,
19361936
"gemma3": apply_liger_kernel_to_gemma3,
19371937
"glm4": apply_liger_kernel_to_glm4,
1938-
"glm4.1v": apply_liger_kernel_to_glm4v,
1938+
"glm4v": apply_liger_kernel_to_glm4v,
19391939
"llama": apply_liger_kernel_to_llama,
19401940
"llama4_text": apply_liger_kernel_to_llama4,
19411941
"llama4": apply_liger_kernel_to_llama4,

test/transformers/test_monkey_patch.py

Lines changed: 4 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1655,60 +1655,6 @@ def test_apply_liger_kernel_to_instance_for_glm4():
16551655
except Exception as e:
16561656
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
16571657

1658-
1659-
@pytest.mark.skipif(not is_glm4v_available(), reason="glm4v module not available")
1660-
def test_apply_liger_kernel_to_instance_for_glm4v():
1661-
# Ensure any monkey patching is cleaned up for subsequent tests
1662-
with patch("transformers.models.glm4v.modeling_glm4v"):
1663-
from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
1664-
1665-
# Instantiate a dummy model
1666-
config = transformers.models.glm4v.configuration_glm4v.Glm4vConfig(
1667-
torch_dtype=torch.bfloat16,
1668-
text_config={
1669-
"num_hidden_layers": 2,
1670-
"rms_norm_eps": 1e-5,
1671-
"hidden_size": 32,
1672-
"intermediate_size": 64,
1673-
"hidden_act": "silu",
1674-
},
1675-
vision_config={
1676-
"num_hidden_layers": 2,
1677-
"layer_norm_eps": 1e-5,
1678-
"hidden_size": 48,
1679-
"intermediate_size": 64,
1680-
},
1681-
)
1682-
dummy_model_instance = AutoModelForCausalLM.from_config(config)
1683-
1684-
# Check that model instance variables are not yet patched with Liger modules
1685-
assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4v_lce_forward)
1686-
assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
1687-
for layer in dummy_model_instance.model.layers:
1688-
assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerPhi3SwiGLUMLP.forward)
1689-
assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
1690-
assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
1691-
assert inspect.getsource(layer.post_self_attn_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
1692-
assert inspect.getsource(layer.post_mlp_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
1693-
1694-
# Test applying kernels to the model instance
1695-
_apply_liger_kernel_to_instance(model=dummy_model_instance)
1696-
1697-
# Check that the model's instance variables were correctly patched with Liger modules
1698-
assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4v_lce_forward)
1699-
assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
1700-
for layer in dummy_model_instance.model.layers:
1701-
assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerPhi3SwiGLUMLP.forward)
1702-
assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
1703-
assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
1704-
assert inspect.getsource(layer.post_self_attn_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
1705-
assert inspect.getsource(layer.post_mlp_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
1706-
try:
1707-
print(dummy_model_instance)
1708-
except Exception as e:
1709-
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
1710-
1711-
17121658
@pytest.mark.skipif(not is_glm4v_available(), reason="glm4v module not available")
17131659
def test_apply_liger_kernel_to_instance_for_glm4v():
17141660
# Ensure any monkey patching is cleaned up for subsequent tests
@@ -1729,7 +1675,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v():
17291675
},
17301676
vision_config={
17311677
"num_hidden_layers": 2,
1732-
"layer_norm_eps": 1e-5,
1678+
"rms_norm_eps": 1e-5,
17331679
"hidden_size": 48,
17341680
"intermediate_size": 64,
17351681
},
@@ -1739,9 +1685,7 @@ def test_apply_liger_kernel_to_instance_for_glm4v():
17391685

17401686
# Check that model instance variables are not yet patched with Liger modules
17411687
assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4v_lce_forward)
1742-
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource(
1743-
LigerRMSNorm.forward
1744-
)
1688+
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
17451689
for layer in dummy_model_instance.language_model.layers:
17461690
assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerPhi3SwiGLUMLP.forward)
17471691
assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
@@ -1757,10 +1701,8 @@ def test_apply_liger_kernel_to_instance_for_glm4v():
17571701
_apply_liger_kernel_to_instance(model=dummy_model_instance)
17581702

17591703
# Check that the model's instance variables were correctly patched with Liger modules
1760-
assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4v_lce_forward)
1761-
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource(
1762-
LigerRMSNorm.forward
1763-
)
1704+
assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4v_lce_forward)
1705+
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
17641706
for layer in dummy_model_instance.language_model.layers:
17651707
assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerPhi3SwiGLUMLP.forward)
17661708
assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)

0 commit comments

Comments
 (0)