Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/research_projects/layer_skip/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# LayerSkip Training Recipe

Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).

## Run training
```
cd scripts
python layer_skip_sft.py
```

## Run benchmark
```
cd scripts
python benchmark_layer_skip.py
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import config
import torch
from torch.utils import benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_tokens(model, inputs):
outputs = model.generate(
**inputs,
do_sample=False,
max_new_tokens=64,
)
return outputs


def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
outputs = model.generate(
**inputs,
assistant_early_exit=assistant_early_exit,
do_sample=False,
max_new_tokens=64,
)
return outputs


if __name__ == "__main__":
ckpt = config.hub_model_id

model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "

results = []
label = "Generation Times"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

results.append(
benchmark.Timer(
stmt="generate_tokens(model, inputs)",
setup="from __main__ import generate_tokens",
globals={"model": model, "inputs": inputs},
num_threads=torch.get_num_threads(),
label=label,
sub_label="no layer skip",
description="generation",
).blocked_autorange()
)

for i in range(1, model.config.num_hidden_layers):
results.append(
benchmark.Timer(
stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
setup="from __main__ import generate_assistant_tokens",
globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
num_threads=torch.get_num_threads(),
label=label,
sub_label=f"layer skip {i}",
description="generation",
).blocked_autorange()
)

benchmark.Compare(results).print()
28 changes: 28 additions & 0 deletions examples/research_projects/layer_skip/scripts/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub import whoami


model_name = "unsloth/Llama-3.2-3B"
tokenizer_name = "unsloth/Llama-3.2-3B"
dataset_name = "WillHeld/top_v2"

output_root_dir = "./checkpoints/"
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
output_dir = f"{output_root_dir}/{hub_model_id}"

per_device_train_batch_size = 8
gradient_accumulation_steps = 1
learning_rate = 2e-5
48 changes: 48 additions & 0 deletions examples/research_projects/layer_skip/scripts/custom_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.early_exit_layer = 0 # initialize with 0
self.always_last_layer = True
self.early_exit_loss_scale = 1.0

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
self.early_exit_layer = (
self.early_exit_layer % (model.config.num_hidden_layers - 1)
) + 1 # rotates between [1, num_hidden_layers-1]
bs, seqlen = inputs.input_ids.shape

labels = inputs.pop("labels")
outputs = model(**inputs, output_hidden_states=True)

hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
if self.early_exit_layer != model.config.num_hidden_layers:
hidden_state = model.model.norm(hidden_state)
logits = model.lm_head(hidden_state)
loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

if self.always_last_layer:
loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
# normalize loss scales
loss = loss / (1.0 + self.early_exit_loss_scale)
else:
loss = loss_early

return loss
91 changes: 91 additions & 0 deletions examples/research_projects/layer_skip/scripts/layer_skip_sft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import config
import torch
from custom_trainer import LayerSkipSFTTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DataCollatorForCompletionOnlyLM, SFTConfig


def formatting_prompts_func(example):
text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"

# Inject eos_token as a string before tokenization, because they are not always added
# See: https://github.com/huggingface/transformers/issues/22794 and
# https://github.com/huggingface/trl/issues/1623
if tokenizer.eos_token: # usually something like "</s>" for GPT2 or "<|endoftext|>"
text += f"{tokenizer.eos_token}"

return text


if __name__ == "__main__":
# load the dataset
print("[INFO] loading the dataset...")
train_dataset = load_dataset(config.dataset_name, split="train")

print(f"output_root_dir: {config.output_root_dir}")
print(f"hub_model_id: {config.hub_model_id}")

# load the model and tokenizer
print("[INFO] loading the model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)

# adding pad and eos tokens if not provided in the tokenizer
if tokenizer.pad_token is None:
# Add '[PAD]' token if it doesn't exist
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
# Add '[EOS]' token if it doesn't exist
tokenizer.add_special_tokens({"eos_token": "[EOS]"})
model.resize_token_embeddings(len(tokenizer))
model.config.eos_token_id = tokenizer.eos_token_id

response_template = " ### Response:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

args = SFTConfig(
do_train=True,
bf16=True,
max_seq_length=None,
per_device_train_batch_size=config.per_device_train_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
learning_rate=config.learning_rate,
packing=False,
num_train_epochs=1.0,
report_to="none",
push_to_hub=True,
hub_model_id=config.hub_model_id,
output_dir=config.output_dir,
logging_steps=500,
save_steps=1000,
save_total_limit=2,
)

trainer = LayerSkipSFTTrainer(
model,
train_dataset=train_dataset,
args=args,
formatting_func=formatting_prompts_func,
data_collator=collator,
)

trainer.train()
Loading