from typing import List, Optional, Tuple, Union

import torch
- import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    _CONFIG_FOR_DOC,
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
- def lce_forward_deprecated(
+ def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
-     r"""
-     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
-
-
-     Args:
-         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-     Returns:
-
-     Example:
-
-     ```python
-     >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
-     >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
-     >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
-     >>> prompt = "Hey, are you conscious? Can you talk to me?"
-     >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-     >>> # Generate
-     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-     ```"""
-     output_attentions = (
-         output_attentions
-         if output_attentions is not None
-         else self.config.output_attentions
-     )
-     output_hidden_states = (
-         output_hidden_states
-         if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     return_dict = (
-         return_dict if return_dict is not None else self.config.use_return_dict
-     )
-
-     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-     outputs = self.model(
-         input_ids=input_ids,
-         attention_mask=attention_mask,
-         position_ids=position_ids,
-         past_key_values=past_key_values,
-         inputs_embeds=inputs_embeds,
-         use_cache=use_cache,
-         output_attentions=output_attentions,
-         output_hidden_states=output_hidden_states,
-         return_dict=return_dict,
-         cache_position=cache_position,
-     )
-
-     hidden_states = outputs[0]
-
-     loss = None
-     logits = None
-
-     if self.training and (labels is not None):
-         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-         shift_labels = labels[..., 1:].contiguous()
-
-         # flatten tokens
-         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-         shift_labels = shift_labels.view(-1)
-
-         lce = LigerFusedLinearCrossEntropyLoss()
-         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-
-     else:
-         if self.config.pretraining_tp > 1:
-             lm_head_slices = self.lm_head.weight.split(
-                 self.vocab_size // self.config.pretraining_tp, dim=0
-             )
-             logits = [
-                 F.linear(hidden_states, lm_head_slices[i])
-                 for i in range(self.config.pretraining_tp)
-             ]
-             logits = torch.cat(logits, dim=-1)
-         else:
-             logits = self.lm_head(hidden_states)
-         if labels is not None:
-             # Upcast to float if we need to compute the loss to avoid potential precision issues
-             logits = logits.float()
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-     if not return_dict:
-         output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
-
-     return CausalLMOutputWithPast(
-         loss=loss,
-         logits=logits,
-         past_key_values=outputs.past_key_values,
-         hidden_states=outputs.hidden_states,
-         attentions=outputs.attentions,
-     )
-
-
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(
-     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
- )
- def lce_forward(
-     self,
-     input_ids=None,
-     attention_mask=None,
-     position_ids=None,
-     past_key_values=None,
-     inputs_embeds=None,
-     labels=None,
-     use_cache=None,
-     output_attentions=None,
-     output_hidden_states=None,
-     return_dict=None,
-     cache_position=None,
-     num_logits_to_keep=0,
-     **kwargs,
+     num_logits_to_keep: int = 0,
+     **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
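What the patch relies on (unchanged by this diff): during training the logits are never materialized; the shifted hidden states and `lm_head.weight` go straight into `LigerFusedLinearCrossEntropyLoss`. Below is a minimal standalone sketch of that call pattern. The shapes, dtypes, and CUDA device are illustrative assumptions, and the import path assumes the class is re-exported from `liger_kernel.transformers` as in current releases; the kernel is Triton-based, so a GPU is required.

```python
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Illustrative shapes only.
batch, seq_len, hidden, vocab = 2, 128, 4096, 32000
hidden_states = torch.randn(batch, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
lm_head_weight = torch.randn(vocab, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")

# Shift so that tokens < n predict n, then flatten, mirroring the patched forward.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden)
shift_labels = labels[..., 1:].contiguous().view(-1)

# The loss takes (weight, input, target); the (tokens, vocab) logits tensor is
# never allocated, which is where the memory saving comes from.
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, shift_hidden, shift_labels)
loss.backward()  # gradients flow to lm_head_weight through the fused kernel
```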
@@ -224,7 +95,6 @@ def lce_forward(
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
-         **kwargs,
    )

    hidden_states = outputs[0]
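The hunk below renames `**kwargs` to `**loss_kwargs` and picks the loss reduction from it: when the Hugging Face Trainer passes `num_items_in_batch` (as it does for its gradient-accumulation loss normalization), the loss is computed with `reduction="sum"` and divided once by the global label-token count instead of being averaged per micro-batch. A small sketch with made-up numbers shows why that matters:

```python
import torch

# Two micro-batches with unequal label-token counts (e.g. one is mostly padding).
per_token_losses = [torch.full((300,), 2.0), torch.full((100,), 4.0)]
num_items_in_batch = sum(t.numel() for t in per_token_losses)  # 400

# Mean per micro-batch, then averaged across steps: (2.0 + 4.0) / 2 = 3.0 (biased).
per_step_mean = sum(t.mean() for t in per_token_losses) / len(per_token_losses)

# Sum per micro-batch, divided once by the global count: 1000 / 400 = 2.5,
# which is what the reduction="sum" path below reproduces.
global_mean = sum(t.sum() for t in per_token_losses) / num_items_in_batch
```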
@@ -245,22 +115,37 @@ def lce_forward(
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

-         reduction = "sum" if "num_items_in_batch" in kwargs else "mean"
+         reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
-             loss /= kwargs["num_items_in_batch"]
+             loss /= loss_kwargs["num_items_in_batch"]

-     else:  # if in inference mode materialize logits
+     elif hasattr(self, "loss_function"):  # if in inference mode materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
-                 **kwargs,
+                 **loss_kwargs,
            )
+     else:
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
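For context on how this forward ends up being used (not shown in the diff): liger-kernel monkey-patches it over the stock `LlamaForCausalLM.forward`, either through its `apply_liger_kernel_to_llama()` helper or manually. A hedged sketch of the manual wiring, assuming the module path `liger_kernel.transformers.model.llama` for the file edited here:

```python
from transformers import LlamaForCausalLM

from liger_kernel.transformers.model.llama import lce_forward

# Swap in the fused-linear-cross-entropy forward before building the model, so
# training uses the no-logits path and eval falls back to the branches above.
LlamaForCausalLM.forward = lce_forward

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```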