17 changes: 9 additions & 8 deletions test/convergence/test_mini_models_multimodal.py
@@ -1,5 +1,13 @@
import functools
import os
+
+import pytest
+import torch
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers.models.auto.processing_auto import AutoProcessor
+
+from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from test.utils import (
UNTOKENIZED_DATASET_PATH,
MiniModelConfig,
@@ -10,14 +18,6 @@
supports_bfloat16,
)

-import pytest
-import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from transformers.models.auto.processing_auto import AutoProcessor
-
-from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
-
try:
# Qwen2-VL is only available in transformers>4.44.2
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
@@ -140,6 +140,7 @@ def preprocess_function(examples):
padding="max_length",
truncation=True,
max_length=1024, # longer than for text-only b/c images require quite a few tokens
+        return_tensors="pt",

@tyler-romero (Contributor, Author) commented on Sep 30, 2024:

Fix for the broken Qwen2VL tests: when the transformers version was bumped from 4.44 to 4.45, it seems the behavior here changed and the processor started returning np arrays.

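For illustration, a minimal sketch of the behavior difference (a hypothetical snippet, not part of this PR; the model id and inputs are assumptions, and the exact untyped return format varies by transformers version):

```python
# Hypothetical illustration of why the test now pins the processor output format.
from transformers.models.auto.processing_auto import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Without return_tensors, the value types depend on the transformers version
# (plain lists or np.ndarray), which breaks downstream code expecting torch.Tensor.
batch = processor(text=["hello world"], padding="max_length", max_length=16)
print(type(batch["input_ids"]))

# Passing return_tensors="pt" pins the output to torch tensors on every version.
batch = processor(
    text=["hello world"],
    padding="max_length",
    max_length=16,
    return_tensors="pt",
)
print(type(batch["input_ids"]))  # torch.Tensor
```
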
)

train_dataset = (
94 changes: 71 additions & 23 deletions test/transformers/test_monkey_patch.py
@@ -15,11 +15,13 @@
LigerSwiGLUMLP,
monkey_patch,
)
+from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import (
MODEL_TYPE_TO_APPLY_LIGER_FN,
_apply_liger_kernel,
_apply_liger_kernel_to_instance,
)
+from test.utils import revert_liger_kernel_to_qwen2_vl


def test_import_from_root():
Expand Down Expand Up @@ -81,14 +83,16 @@ def dummy_apply_liger_kernal_to_llama(

with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}):
mock_llama.__signature__ = apply_liger_kernal_to_llama_sig
-        _apply_liger_kernel(
-            "llama",
-            rope=False,
-            fused_linear_cross_entropy=False,
-            cross_entropy=True,
-            foobar=True,
-            barbaz=False,
-        ),
+        (
+            _apply_liger_kernel(
+                "llama",
+                rope=False,
+                fused_linear_cross_entropy=False,
+                cross_entropy=True,
+                foobar=True,
+                barbaz=False,
+            ),
+        )
mock_llama.assert_called_once()
mock_llama.assert_called_once_with(
rope=False,
@@ -150,14 +154,16 @@ def dummy_apply_liger_kernel_to_llama(

with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}):
mock_llama.__signature__ = apply_liger_kernel_to_llama_sig
-        _apply_liger_kernel_to_instance(
-            model=mock_llama_model_instance,
-            rope=False,
-            fused_linear_cross_entropy=False,
-            cross_entropy=True,
-            foobar=True,
-            barbaz=False,
-        ),
+        (
+            _apply_liger_kernel_to_instance(
+                model=mock_llama_model_instance,
+                rope=False,
+                fused_linear_cross_entropy=False,
+                cross_entropy=True,
+                foobar=True,
+                barbaz=False,
+            ),
+        )
mock_llama.assert_called_once()
mock_llama.assert_called_once_with(
model=mock_llama_model_instance,
@@ -199,7 +205,6 @@ def test_patching_apis_support_patching_model_instance():
def test_apply_liger_kernel_to_instance_for_llama():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.llama.modeling_llama"):
-
# Instantiate a dummy model
config = transformers.models.llama.configuration_llama.LlamaConfig(
torch_dtype=torch.bfloat16,
@@ -232,7 +237,6 @@ def test_apply_liger_kernel_to_instance_for_llama():
def test_apply_liger_kernel_to_instance_for_mistral():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mistral.modeling_mistral"):
-
# Instantiate a dummy model
config = transformers.models.mistral.configuration_mistral.MistralConfig(
torch_dtype=torch.bfloat16,
@@ -265,7 +269,6 @@ def test_apply_liger_kernel_to_instance_for_mistral():
def test_apply_liger_kernel_to_instance_for_mixtral():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mixtral.modeling_mixtral"):
-
# Instantiate a dummy model
config = transformers.models.mixtral.configuration_mixtral.MixtralConfig(
torch_dtype=torch.bfloat16,
@@ -302,7 +305,6 @@ def test_apply_liger_kernel_to_instance_for_mixtral():
def test_apply_liger_kernel_to_instance_for_gemma():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.gemma.modeling_gemma"):
-
# Instantiate a dummy model
config = transformers.models.gemma.configuration_gemma.GemmaConfig(
torch_dtype=torch.bfloat16,
@@ -335,7 +337,6 @@ def test_apply_liger_kernel_to_instance_for_gemma():
def test_apply_liger_kernel_to_instance_for_gemma2():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.gemma2.modeling_gemma2"):
-
# Instantiate a dummy model
config = transformers.models.gemma2.configuration_gemma2.Gemma2Config(
torch_dtype=torch.bfloat16,
@@ -372,7 +373,6 @@ def test_apply_liger_kernel_to_instance_for_gemma2():
def test_apply_liger_kernel_to_instance_for_qwen2():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.qwen2.modeling_qwen2"):
-
# Instantiate a dummy model
config = transformers.models.qwen2.configuration_qwen2.Qwen2Config(
torch_dtype=torch.bfloat16,
@@ -402,10 +402,58 @@ def test_apply_liger_kernel_to_instance_for_qwen2():
assert isinstance(layer.post_attention_layernorm, LigerRMSNorm)


+def test_apply_liger_kernel_to_instance_for_qwen2_vl():
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLForConditionalGeneration,
+    )
+
+    # Instantiate a dummy model
+    config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig(
+        torch_dtype=torch.bfloat16,
+        rms_norm_eps=1e-5,
+        hidden_size=64,
+        intermediate_size=64,
+        embed_dim=32,
+        hidden_act="silu",
+        num_hidden_layers=2,
+    )
+    dummy_model_instance = Qwen2VLForConditionalGeneration._from_config(config)
+
+    assert isinstance(dummy_model_instance, Qwen2VLForConditionalGeneration)
+
+    # Check that model instance variables are not yet patched with Liger modules
+    assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm)
+    for layer in dummy_model_instance.model.layers:
+        assert not isinstance(layer.mlp, LigerSwiGLUMLP)
+        assert not isinstance(layer.input_layernorm, LigerRMSNorm)
+        assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm)
+    for vision_block in dummy_model_instance.visual.blocks:
+        assert not isinstance(vision_block.norm1, LigerLayerNorm)
+        assert not isinstance(vision_block.norm2, LigerLayerNorm)
+
+    # Test applying kernels to the model instance
+    _apply_liger_kernel_to_instance(model=dummy_model_instance)
+
+    # Check that the model's instance variables were correctly patched with Liger modules
+    assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm)
+    for layer in dummy_model_instance.model.layers:
+        assert isinstance(layer.mlp, LigerSwiGLUMLP)
+        assert isinstance(layer.input_layernorm, LigerRMSNorm)
+        assert isinstance(layer.post_attention_layernorm, LigerRMSNorm)
+    for vision_block in dummy_model_instance.visual.blocks:
+        assert isinstance(vision_block.norm1, LigerLayerNorm)
+        assert isinstance(vision_block.norm2, LigerLayerNorm)
+
+    # Ensure any monkey patching is cleaned up for subsequent tests.
+    # Using `with patch("transformers.models.qwen2_vl.modeling_qwen2_vl")` does not
+    # work here, due to the way `modeling_qwen2_vl` is used in the monkey patch fn.
+    revert_liger_kernel_to_qwen2_vl()
+
+
def test_apply_liger_kernel_to_instance_for_phi3():
# Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.phi3.modeling_phi3"):
-
# Instantiate a dummy model
config = transformers.models.phi3.configuration_phi3.Phi3Config(
torch_dtype=torch.bfloat16,
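For context, a hedged sketch of what the `revert_liger_kernel_to_qwen2_vl` helper imported from `test.utils` above might look like, assuming it works by reloading the patched modeling module (the actual implementation lives in test/utils.py):

```python
# Hypothetical sketch of the revert helper; see test/utils.py for the real one.
import importlib

from transformers.models.qwen2_vl import modeling_qwen2_vl


def revert_liger_kernel_to_qwen2_vl():
    # Reloading restores the original (unpatched) module-level classes,
    # which `with patch(...)` cannot do once the monkey patch has rebound them.
    importlib.reload(modeling_qwen2_vl)
```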