21 changes: 21 additions & 0 deletions src/liger_kernel/transformers/auto_model.py
@@ -1,11 +1,14 @@
import inspect
import logging

from transformers import AutoConfig
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel

logger = logging.getLogger(__name__)


def _get_model_config(model_dir, **model_init_kwargs):
    config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
@@ -36,3 +39,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)

    @classmethod
    def from_config(cls, config, **kwargs):
        model_type = getattr(config, "model_type", None)
        if not model_type:
            logger.info("Model type could not be determined from the model config. No Liger kernels will be applied.")
            return super().from_config(config, **kwargs)

        _apply_liger_kernel(model_type, **kwargs)

        # Filter out the kwargs that were consumed by the apply_liger_* function, since they
        # would otherwise cause model initialization errors
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
        apply_fn_signature = inspect.signature(apply_fn)
        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

        return super().from_config(config, **applicable_kwargs)
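
For context, a minimal usage sketch of the new classmethod (not part of the diff): the checkpoint name is illustrative, and it assumes AutoLigerKernelForCausalLM is importable from liger_kernel.transformers.

# Hypothetical usage of from_config; model id and kwargs are illustrative
from transformers import AutoConfig

from liger_kernel.transformers import AutoLigerKernelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
# Liger-specific kwargs (e.g. rope, swiglu) are routed to the apply_liger_kernel_to_* patch fn;
# all remaining kwargs are forwarded to transformers' AutoModelForCausalLM.from_config
model = AutoLigerKernelForCausalLM.from_config(config, rope=False, swiglu=True)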
35 changes: 35 additions & 0 deletions test/transformers/test_auto_model.py
@@ -52,3 +52,38 @@ def test_auto_liger_kernel_for_causal_lm_from_pretrained():
            pretrained_model_name_or_path, *model_args, **original_kwargs
        )
        assert model == "mock_model"


def test_auto_liger_kernel_for_causal_lm_from_config():
    original_kwargs = {
        "valid_arg_1": "some_value_1",
        "valid_arg_2": 10,
    }

    # These args should be passed through to the apply_liger_kernel_to_llama fn
    apply_liger_kernel_kwargs = {
        "rope": False,
        "swiglu": True,
    }

    kwargs = {**original_kwargs, **apply_liger_kernel_kwargs}

    # Mock the model config instance passed to from_config()
    mock_model_config = MagicMock()
    mock_model_config.model_type = "llama"
    mock_llama = mock.Mock()

    with (
        patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}),
        mock.patch.object(AutoModelForCausalLM, "from_config", return_value="mock_model") as mock_super_from_config,
    ):
        # Mock the function signature of apply_liger_kernel_to_llama
        mock_llama.__signature__ = signature(apply_liger_kernel_to_llama)

        model = AutoLigerKernelForCausalLM.from_config(mock_model_config, **kwargs)

        # Check that the apply_liger_kernel_to_llama mock was called with the correct kwargs
        mock_llama.assert_called_once_with(rope=False, swiglu=True)
        # Check that the original kwargs are passed through to super().from_config
        mock_super_from_config.assert_called_once_with(mock_model_config, **original_kwargs)
        assert model == "mock_model"
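
As a side note, the kwarg-filtering behavior this test exercises can be sketched in isolation; the toy function below is a stand-in for the real patch function, used only for its signature.

# Minimal sketch of signature-based kwarg filtering, assuming a toy apply fn
import inspect

def toy_apply_liger_kernel_to_llama(rope=True, swiglu=True):
    """Stand-in for the real patch function; only its signature matters here."""

kwargs = {"rope": False, "swiglu": True, "valid_arg_1": "some_value_1", "valid_arg_2": 10}
params = inspect.signature(toy_apply_liger_kernel_to_llama).parameters
# Kwargs consumed by the patch fn are dropped before model construction
applicable_kwargs = {k: v for k, v in kwargs.items() if k not in params}
assert applicable_kwargs == {"valid_arg_1": "some_value_1", "valid_arg_2": 10}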