File tree Expand file tree Collapse file tree 5 files changed +10
-5
lines changed
src/liger_kernel/transformers/model Expand file tree Collapse file tree 5 files changed +10
-5
lines changed Original file line number Diff line number Diff line change @@ -120,8 +120,9 @@ def lce_forward(
120
120
logits = torch .cat (logits , dim = - 1 )
121
121
else :
122
122
logits = self .lm_head (hidden_states )
123
- logits = logits .float ()
124
123
if labels is not None :
124
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
125
+ logits = logits .float ()
125
126
# Shift so that tokens < n predict n
126
127
shift_logits = logits [..., :- 1 , :].contiguous ()
127
128
shift_labels = labels [..., 1 :].contiguous ()
Original file line number Diff line number Diff line change @@ -103,7 +103,6 @@ def lce_forward(
103
103
104
104
hidden_states = outputs [0 ]
105
105
logits = self .lm_head (hidden_states )
106
- logits = logits .float ()
107
106
108
107
loss = None
109
108
if self .training and (labels is not None ):
@@ -116,6 +115,8 @@ def lce_forward(
116
115
lce = LigerFusedLinearCrossEntropyLoss ()
117
116
loss = lce (self .lm_head .weight , shift_hidden_states , shift_labels )
118
117
elif labels is not None :
118
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
119
+ logits = logits .float ()
119
120
# Shift so that tokens < n predict n
120
121
shift_logits = logits [..., :- 1 , :].contiguous ()
121
122
shift_labels = labels [..., 1 :].contiguous ()
Original file line number Diff line number Diff line change @@ -108,10 +108,11 @@ def lce_forward(
108
108
loss = lce (self .lm_head .weight , shift_hidden_states , shift_labels )
109
109
else :
110
110
logits = self .lm_head (hidden_states )
111
- logits = logits .float ()
112
111
113
112
loss = None
114
113
if labels is not None :
114
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
115
+ logits = logits .float ()
115
116
# Shift so that tokens < n predict n
116
117
shift_logits = logits [..., :- 1 , :].contiguous ()
117
118
shift_labels = labels [..., 1 :].contiguous ()
Original file line number Diff line number Diff line change @@ -109,8 +109,9 @@ def lce_forward(
109
109
110
110
else :
111
111
logits = self .lm_head (hidden_states )
112
- logits = logits .float ()
113
112
if labels is not None :
113
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
114
+ logits = logits .float ()
114
115
# Shift so that tokens < n predict n
115
116
shift_logits = logits [..., :- 1 , :].contiguous ()
116
117
shift_labels = labels [..., 1 :].contiguous ()
Original file line number Diff line number Diff line change @@ -150,8 +150,9 @@ def lce_forward(
150
150
loss = lce (self .lm_head .weight , shift_hidden_states , shift_labels )
151
151
else :
152
152
logits = self .lm_head (hidden_states )
153
- logits = logits .float ()
154
153
if labels is not None :
154
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
155
+ logits = logits .float ()
155
156
# Shift so that tokens < n predict n
156
157
shift_logits = logits [..., :- 1 , :].contiguous ()
157
158
shift_labels = labels [..., 1 :].contiguous ()
You can’t perform that action at this time.
0 commit comments