
Commit 9618403

fix llama & gemma's forward patch
1 parent 1b04de6 commit 9618403

3 files changed (+45, -148 lines)


src/liger_kernel/transformers/model/gemma.py

Lines changed: 18 additions & 2 deletions
@@ -35,6 +35,8 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -106,11 +108,25 @@ def lce_forward(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)

-        lce = LigerFusedLinearCrossEntropyLoss()
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    elif hasattr(self, "loss_function"):  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )

     else:
-        logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
             logits = logits.float()
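The training branch above now follows the gradient-accumulation fix introduced in transformers 4.46.0: when the Trainer passes `num_items_in_batch` through `**loss_kwargs`, the fused loss is summed over tokens and then divided by that global token count, instead of taking a per-micro-batch mean. Below is a minimal, self-contained sketch of that pattern, assuming a CUDA device; the tensor shapes and the `num_items_in_batch` value are illustrative, not taken from the patched module.

```python
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Illustrative stand-ins for self.lm_head.weight and the flattened,
# shifted hidden states / labels built earlier in the forward pass.
hidden_size, vocab_size, num_tokens = 64, 128, 16
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
shift_hidden_states = torch.randn(num_tokens, hidden_size, device="cuda")
shift_labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# Hypothetical value the HF Trainer would supply via **loss_kwargs on >= 4.46.0
# when accumulating gradients across micro-batches.
loss_kwargs = {"num_items_in_batch": 48}

reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
if reduction == "sum":
    # Normalize by the total token count of the whole batch so accumulated
    # micro-batch losses average correctly.
    loss = loss / loss_kwargs["num_items_in_batch"]
```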

src/liger_kernel/transformers/model/llama.py

Lines changed: 24 additions & 139 deletions
@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple, Union

 import torch
-import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     _CONFIG_FOR_DOC,
@@ -22,150 +22,21 @@
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward_deprecated(
+def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
-
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
-    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
-    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
-    >>> prompt = "Hey, are you conscious? Can you talk to me?"
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-    ```"""
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-
-    else:
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(
-                self.vocab_size // self.config.pretraining_tp, dim=0
-            )
-            logits = [
-                F.linear(hidden_states, lm_head_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
-def lce_forward(
-    self,
-    input_ids=None,
-    attention_mask=None,
-    position_ids=None,
-    past_key_values=None,
-    inputs_embeds=None,
-    labels=None,
-    use_cache=None,
-    output_attentions=None,
-    output_hidden_states=None,
-    return_dict=None,
-    cache_position=None,
-    num_logits_to_keep=0,
-    **kwargs,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -224,7 +95,6 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        **kwargs,
     )

     hidden_states = outputs[0]
@@ -245,22 +115,37 @@ def lce_forward(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
         shift_labels = shift_labels.view(-1)

-        reduction = "sum" if "num_items_in_batch" in kwargs else "mean"
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
         lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
         if reduction == "sum":
-            loss /= kwargs["num_items_in_batch"]
+            loss /= loss_kwargs["num_items_in_batch"]

-    else:  # if in inference mode materialize logits
+    elif hasattr(self, "loss_function"):  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **kwargs,
+                **loss_kwargs,
             )
+    else:
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)

     if not return_dict:
         output = (logits,) + outputs[1:]
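Both inference branches now project only the last `num_logits_to_keep` hidden states through `lm_head`, matching the `num_logits_to_keep` argument of recent transformers releases. A tiny standalone illustration of the slicing, with made-up shapes: the default `0` gives the slice `[:, -0:, :]`, which keeps the full sequence, so existing callers see no behavior change.

```python
import torch

hidden_states = torch.randn(2, 8, 4)  # (batch, seq_len, hidden_size)

for num_logits_to_keep in (0, 1, 3):
    kept = hidden_states[:, -num_logits_to_keep:, :]
    print(num_logits_to_keep, tuple(kept.shape))
# 0 -> (2, 8, 4)  (-0 == 0, so the whole sequence is kept)
# 1 -> (2, 1, 4)  (only the final position is projected to logits)
# 3 -> (2, 3, 4)
```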

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 3 additions & 7 deletions
@@ -12,9 +12,6 @@
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
@@ -95,14 +92,13 @@ def apply_liger_kernel_to_llama(
     if cross_entropy:
         modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.46.0"):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
-        else:  # if version < 4.46.0
+        if transformer_version < version.parse("4.46.0"):
             logger.warning(
                 "Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
                 "Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
             )
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+
+        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
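With the deprecated path removed, `apply_liger_kernel_to_llama` always installs `llama_lce_forward` and merely warns on transformers < 4.46.0. A hedged usage sketch follows; the checkpoint name is illustrative (borrowed from the removed docstring above), and flags other than `fused_linear_cross_entropy` are left at their defaults.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch LlamaForCausalLM.forward (and any other enabled Liger kernels)
# before instantiating the model so the patched class method is used.
apply_liger_kernel_to_llama(fused_linear_cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```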
