Commit 52bb235

ariG23498, mostafaelhoushi, and qgallouedec authored and committed

🐇 [Research] Layer Skip SFT (huggingface#3111)

Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>

1 parent 20d6fef commit 52bb235

File tree

5 files changed: +259 -0 lines changed
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# LayerSkip Training Recipe

Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).

## Run training

```
cd scripts
python layer_skip_sft.py
```

## Run benchmark

```
cd scripts
python benchmark_layer_skip.py
```
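For reference, a minimal inference sketch showing how a trained checkpoint can be used for self-speculative decoding via the `assistant_early_exit` generation argument exercised by the benchmark script below. The repo id and the exit layer (4) are placeholders; `config.py` in this commit builds the id as `<username>/layerskip-<model>-<dataset>`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id; config.py produces "<username>/layerskip-Llama-3.2-3B-top_v2"
ckpt = "<username>/layerskip-Llama-3.2-3B-top_v2"

model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Early exit at layer 4 drafts tokens that the full model then verifies
# (self-speculative decoding); greedy outputs are preserved.
outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```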
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import config
import torch
from torch.utils import benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_tokens(model, inputs):
    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=64,
    )
    return outputs


def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
    outputs = model.generate(
        **inputs,
        assistant_early_exit=assistant_early_exit,
        do_sample=False,
        max_new_tokens=64,
    )
    return outputs


if __name__ == "__main__":
    ckpt = config.hub_model_id

    model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "

    results = []
    label = "Generation Times"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Baseline: plain greedy generation with the full model
    results.append(
        benchmark.Timer(
            stmt="generate_tokens(model, inputs)",
            setup="from __main__ import generate_tokens",
            globals={"model": model, "inputs": inputs},
            num_threads=torch.get_num_threads(),
            label=label,
            sub_label="no layer skip",
            description="generation",
        ).blocked_autorange()
    )

    # Self-speculative decoding, sweeping every candidate early-exit layer
    for i in range(1, model.config.num_hidden_layers):
        results.append(
            benchmark.Timer(
                stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
                setup="from __main__ import generate_tokens_with_assistance",
                globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
                num_threads=torch.get_num_threads(),
                label=label,
                sub_label=f"layer skip {i}",
                description="generation",
            ).blocked_autorange()
        )

    benchmark.Compare(results).print()
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub import whoami


model_name = "unsloth/Llama-3.2-3B"
tokenizer_name = "unsloth/Llama-3.2-3B"
dataset_name = "WillHeld/top_v2"

output_root_dir = "./checkpoints/"
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
output_dir = f"{output_root_dir}/{hub_model_id}"

per_device_train_batch_size = 8
gradient_accumulation_steps = 1
learning_rate = 2e-5
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_exit_layer = 0  # initialize with 0
        self.always_last_layer = True
        self.early_exit_loss_scale = 1.0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        self.early_exit_layer = (
            self.early_exit_layer % (model.config.num_hidden_layers - 1)
        ) + 1  # rotates between [1, num_hidden_layers-1]
        bs, seqlen = inputs.input_ids.shape

        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)

        hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
        if self.early_exit_layer != model.config.num_hidden_layers:
            hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        if self.always_last_layer:
            loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
            loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
            # normalize loss scales
            loss = loss / (1.0 + self.early_exit_loss_scale)
        else:
            loss = loss_early

        return loss
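For intuition: `compute_loss` above cycles the supervised early-exit layer every step and, when `always_last_layer` is set, combines losses as `(early_exit_loss_scale * loss_early + loss_last) / (1 + early_exit_loss_scale)`. Below is a standalone sketch of the rotation, assuming a hypothetical model with 4 hidden layers; it only illustrates the indexing arithmetic, not the trainer itself.

```python
# Standalone illustration of the early-exit layer rotation used in compute_loss,
# assuming a hypothetical model with num_hidden_layers = 4.
num_hidden_layers = 4
early_exit_layer = 0  # same initialization as LayerSkipSFTTrainer.__init__

schedule = []
for step in range(6):
    early_exit_layer = (early_exit_layer % (num_hidden_layers - 1)) + 1
    schedule.append(early_exit_layer)

print(schedule)  # [1, 2, 3, 1, 2, 3]; the final layer is supervised separately via always_last_layer
```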
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import config
import torch
from custom_trainer import LayerSkipSFTTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DataCollatorForCompletionOnlyLM, SFTConfig


def formatting_prompts_func(example):
    text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"

    # Inject eos_token as a string before tokenization, because it is not always added
    # See: https://github.com/huggingface/transformers/issues/22794 and
    # https://github.com/huggingface/trl/issues/1623
    if tokenizer.eos_token:  # usually something like "</s>" for GPT2 or "<|endoftext|>"
        text += f"{tokenizer.eos_token}"

    return text


if __name__ == "__main__":
    # load the dataset
    print("[INFO] loading the dataset...")
    train_dataset = load_dataset(config.dataset_name, split="train")

    print(f"output_root_dir: {config.output_root_dir}")
    print(f"hub_model_id: {config.hub_model_id}")

    # load the model and tokenizer
    print("[INFO] loading the model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)

    # adding pad and eos tokens if not provided in the tokenizer
    if tokenizer.pad_token is None:
        # Add '[PAD]' token if it doesn't exist
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

    if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
        # Add '[EOS]' token if it doesn't exist
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.eos_token_id = tokenizer.eos_token_id

    response_template = " ### Response:"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    args = SFTConfig(
        do_train=True,
        bf16=True,
        max_seq_length=None,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        packing=False,
        num_train_epochs=1.0,
        report_to="none",
        push_to_hub=True,
        hub_model_id=config.hub_model_id,
        output_dir=config.output_dir,
        logging_steps=500,
        save_steps=1000,
        save_total_limit=2,
    )

    trainer = LayerSkipSFTTrainer(
        model,
        train_dataset=train_dataset,
        args=args,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )

    trainer.train()
