Add glm4.1v model support #858


Merged

merged 21 commits on Aug 19, 2025

Commits
4286dfe
feat(glm4v): implement lce_forward function with logits_to_keep param…
vvvdwbvvv Aug 15, 2025
98a2aad
feat(utils): add revert function for Glm4v kernel patches
vvvdwbvvv Aug 15, 2025
14f5c30
feat(glm4v): add support for Glm4v model in mini model setups
vvvdwbvvv Aug 15, 2025
b311601
feat(glm4v): add support for Glm4v model in mini model setups with ap…
vvvdwbvvv Aug 15, 2025
1d23e96
feat(utils): add revert function for Glm4v kernel patches
vvvdwbvvv Aug 15, 2025
fb8195e
feat(transformers): add Glm4v kernel application to monkey patch imports
vvvdwbvvv Aug 15, 2025
bd15a4b
feat(transformers): add Liger kernel application for GLM-4v models
vvvdwbvvv Aug 15, 2025
313575a
feat(transformers): add support for glm4.1v model in Liger kernel app…
vvvdwbvvv Aug 15, 2025
210a056
fix(transformers): update Glm4v MLP patch to use LigerPhi3SwiGLUMLP
vvvdwbvvv Aug 15, 2025
e3eb435
feat(transformers): add support for glm4v model in monkey patch tests
vvvdwbvvv Aug 15, 2025
6e60b90
fix(transformers): update imports for glm4v model in apply_liger_kern…
vvvdwbvvv Aug 17, 2025
98ea10c
fix(transformers): update import path for Glm4vConfig in test_apply_l…
vvvdwbvvv Aug 17, 2025
e7e61e6
feat(transformers): add support for glm4v model in monkey patch tests
vvvdwbvvv Aug 17, 2025
a60d315
fix(transformers): update layer normalization patching in apply_liger…
vvvdwbvvv Aug 17, 2025
0f89e54
fix(tests): clean up formatting in test_apply_liger_kernel_to_instanc…
vvvdwbvvv Aug 17, 2025
65a4cb0
feat(transformers): add support for apply_liger_kernel_to_glm4v function
vvvdwbvvv Aug 17, 2025
e9c82c8
fix(transformers): update import paths for Glm4vConfig and Glm4vForCo…
vvvdwbvvv Aug 17, 2025
d632e09
fix(tests): update import paths for Glm4vConfig and Glm4vForCondition…
vvvdwbvvv Aug 17, 2025
a127fe8
feat(transformers): add image and video token configurations to GLM4V…
vvvdwbvvv Aug 17, 2025
8e42906
fix: modify atol to pass test on mini model with logits
vvvdwbvvv Aug 17, 2025
6654c79
Merge branch 'main' into add-glm4.1v
lancerts Aug 19, 2025
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -31,6 +31,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -89,6 +90,7 @@ def __getattr__(name: str):
        "apply_liger_kernel_to_gemma3",
        "apply_liger_kernel_to_gemma3_text",
        "apply_liger_kernel_to_glm4",
        "apply_liger_kernel_to_glm4v",
        "apply_liger_kernel_to_granite",
        "apply_liger_kernel_to_llama",
        "apply_liger_kernel_to_llava",
@@ -148,6 +150,7 @@ def __getattr__(name: str):
        "apply_liger_kernel_to_gemma3",
        "apply_liger_kernel_to_gemma3_text",
        "apply_liger_kernel_to_glm4",
        "apply_liger_kernel_to_glm4v",
        "apply_liger_kernel_to_granite",
        "apply_liger_kernel_to_llama",
        "apply_liger_kernel_to_llava",
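For reference, this export makes the GLM-4v patch callable straight from the top-level package. A minimal sketch of pre-load usage (the checkpoint name is reused from the docstring example below; the flags shown are the defaults):

```python
import torch
from transformers import Glm4vForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v

# Called with no `model`, this patches the transformers glm4v module itself:
# Glm4vRMSNorm is swapped for LigerRMSNormForGlm4, and the conditional-generation
# forward is replaced with the fused-linear-cross-entropy `lce_forward`.
apply_liger_kernel_to_glm4v(
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
)

model = Glm4vForConditionalGeneration.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```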
150 changes: 150 additions & 0 deletions src/liger_kernel/transformers/model/glm4v.py
@@ -0,0 +1,150 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token saves memory, which becomes significant for long sequences or large vocabularies.
            If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration

    >>> MODEL_PATH = "THUDM/GLM-4.1V-9B-Thinking"
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
    ...             },
    ...             {
    ...                 "type": "text",
    ...                 "text": "describe this image"
    ...             }
    ...         ],
    ...     }
    ... ]
    >>> processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    >>> model = Glm4vForConditionalGeneration.from_pretrained(
    ...     pretrained_model_name_or_path=MODEL_PATH,
    ...     torch_dtype=torch.bfloat16,
    ...     device_map="auto",
    ... )
    >>> inputs = processor.apply_chat_template(
    ...     messages,
    ...     tokenize=True,
    ...     add_generation_prompt=True,
    ...     return_dict=True,
    ...     return_tensors="pt"
    ... ).to(model.device)
    >>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
    >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    <think>Got it, let's describe the image. First, there's a vintage car, specifically a Volkswagen Beetle
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
90 changes: 90 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -1839,13 +1839,103 @@ def apply_liger_kernel_to_glm4(
                _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)


def apply_liger_kernel_to_glm4v(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace the original implementation in HuggingFace GLM-4v models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory-efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4v import modeling_glm4v
    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel

    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4v models.")
    if rms_norm:
        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(glm4v_lce_forward, model)
        else:
            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
            # Note: the language_model and visual properties can be accessed through the conditional
            # generation class for backward compatibility, though this may change in the future.
            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
            text_model: Glm4vTextModel = model.language_model
            vision_model: Glm4vVisionModel = model.visual
        elif isinstance(model, Glm4vTextModel):
            text_model: Glm4vTextModel = model
            vision_model = None
        else:
            # Note: There is currently no support for patching the vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported glm4.1v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
            )

        if vision_model is not None:
            for vision_block in vision_model.blocks:
                if rms_norm:
                    _patch_rms_norm_module(vision_block.norm1)
                    _patch_rms_norm_module(vision_block.norm2)
                if swiglu:
                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)

        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    "gemma3_text": apply_liger_kernel_to_gemma3_text,
    "gemma3": apply_liger_kernel_to_gemma3,
    "glm4": apply_liger_kernel_to_glm4,
    "glm4v": apply_liger_kernel_to_glm4v,
    "llama": apply_liger_kernel_to_llama,
    "llama4_text": apply_liger_kernel_to_llama4,
    "llama4": apply_liger_kernel_to_llama4,
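The patch also supports an already-loaded model, taking the `model is not None` branch above. A sketch of that path (checkpoint name illustrative):

```python
import torch
from transformers import Glm4vForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v

model = Glm4vForConditionalGeneration.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

# With a loaded instance, `model.forward` is rebound to `lce_forward` via
# MethodType, and the already-instantiated RMSNorm and MLP modules in both the
# vision blocks and the text decoder layers are patched in place.
apply_liger_kernel_to_glm4v(model=model)
```

The `MODEL_TYPE_TO_APPLY_LIGER_FN` entry is what lets Liger's automatic dispatch pick this function up for checkpoints whose config reports `model_type == "glm4v"`.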