Commit 04d5a0e

Author: Matthew Hoffman

Move logits.float() call (#308)
## Summary

The analogous `logits.float()` calls were moved in the Hugging Face modeling source code to be inside the `if labels is not None` block, so that logits are only upcast when they are used in a loss calculation; this avoids a memory spike during inference when the model runs in lower precision.

* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/llama/modeling_llama.py#L1211-L1212
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/mixtral/modeling_mixtral.py#L1329-L1330
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/phi3/modeling_phi3.py#L1303-L1304
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/qwen2/modeling_qwen2.py#L1206-L1207

Some of the Liger-Kernel models already have this change:

* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/mistral.py#L114-L116
* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/gemma.py#L114-L116

See also:

* huggingface/transformers#30860

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
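The change applied across the files below follows one pattern: defer the float32 upcast until a loss is actually computed. A minimal sketch of that pattern, using a hypothetical `FakeTensor` stand-in (not the Liger-Kernel or torch API) so the dtype behavior is visible without a GPU:

```python
class FakeTensor:
    """Tiny stand-in for a torch tensor, tracking only data and dtype."""

    def __init__(self, data, dtype="float16"):
        self.data = data
        self.dtype = dtype

    def float(self):
        # Mimics torch's Tensor.float(): returns a float32 copy.
        return FakeTensor(self.data, dtype="float32")


def lce_forward_sketch(hidden_states, lm_head, labels=None):
    """Toy forward pass illustrating the deferred-upcast pattern."""
    logits = lm_head(hidden_states)
    # Before this change, `logits = logits.float()` ran unconditionally here,
    # materializing a full-precision copy of the logits even during inference.
    loss = None
    if labels is not None:
        # Upcast only when computing the loss: cross-entropy benefits from
        # float32 precision, but inference never pays the memory cost.
        logits = logits.float()
        loss = sum((p - y) ** 2 for p, y in zip(logits.data, labels))  # stand-in loss
    return logits, loss


# Inference: logits stay in the model's lower precision.
logits, loss = lce_forward_sketch([0.1, 0.9], lambda h: FakeTensor(h))
print(logits.dtype, loss)  # float16 None

# Training: logits are upcast before the loss is computed.
logits, loss = lce_forward_sketch([0.1, 0.9], lambda h: FakeTensor(h), labels=[0.0, 1.0])
print(logits.dtype)  # float32
```

In the real models the inference-path logits keep the dtype of `lm_head` (e.g. bfloat16), which is what avoids the memory spike described above.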
Parent: ff6650b · Commit: 04d5a0e

File tree: 5 files changed (+10, −5 lines)

src/liger_kernel/transformers/model/llama.py (2 additions, 1 deletion)

```diff
@@ -120,8 +120,9 @@ def lce_forward(
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
```

src/liger_kernel/transformers/model/mixtral.py (2 additions, 1 deletion)

```diff
@@ -103,7 +103,6 @@ def lce_forward(

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()

         loss = None
         if self.training and (labels is not None):
@@ -116,6 +115,8 @@ def lce_forward(
             lce = LigerFusedLinearCrossEntropyLoss()
             loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
         elif labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
```

src/liger_kernel/transformers/model/phi3.py (2 additions, 1 deletion)

```diff
@@ -108,10 +108,11 @@ def lce_forward(
             loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
         else:
             logits = self.lm_head(hidden_states)
-            logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
```

src/liger_kernel/transformers/model/qwen2.py (2 additions, 1 deletion)

```diff
@@ -109,8 +109,9 @@ def lce_forward(

         else:
             logits = self.lm_head(hidden_states)
-            logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
```

src/liger_kernel/transformers/model/qwen2_vl.py (2 additions, 1 deletion)

```diff
@@ -150,8 +150,9 @@ def lce_forward(
             loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
         else:
             logits = self.lm_head(hidden_states)
-            logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
```
