🐇 [Research] Layer Skip SFT #3111

Merged · 7 commits · Mar 24, 2025
11 changes: 11 additions & 0 deletions examples/research_projects/layer_skip/README.md
@@ -0,0 +1,11 @@
# Layer Skip SFT

This example fine-tunes a causal LM with an auxiliary early-exit loss computed at a rotating intermediate layer (LayerSkip-style SFT). The pushed checkpoint can then be used for early-exit self-speculative decoding at inference time via `assistant_early_exit`.

## Run training
```
cd scripts
python layer_skip_sft.py
```

## Run benchmark
```
cd scripts
python benchmark_layer_skip.py
```
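
## Try the trained checkpoint

A minimal inference sketch (assuming the default `hub_model_id` defined in `scripts/config.py` and a `transformers` release recent enough to support `assistant_early_exit`). With `assistant_early_exit=4`, tokens are drafted with only the first 4 decoder layers and then verified with the full model:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default hub_model_id from scripts/config.py; adjust to your own push target.
ckpt = "ariG23498/layerskip-Llama-3.2-3B-top_v2"

model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Early-exit self-speculative decoding: draft with the first 4 layers, verify with all layers.
outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```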
63 changes: 63 additions & 0 deletions examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py
@@ -0,0 +1,63 @@
import config
import torch
from torch.utils import benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_tokens(model, inputs):
    # Baseline: greedy generation with the full model.
    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=64,
    )
    return outputs


def generate_assistant_tokens(model, inputs, assistant_early_exit):
    # Early-exit self-speculative decoding: draft with the first
    # `assistant_early_exit` layers, verify with the full model.
    outputs = model.generate(
        **inputs,
        assistant_early_exit=assistant_early_exit,
        do_sample=False,
        max_new_tokens=64,
    )
    return outputs


if __name__ == "__main__":
    ckpt = config.hub_model_id

    model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "

    results = []
    label = "Generation Speeds"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Time the baseline (no early exit).
    results.append(
        benchmark.Timer(
            stmt="generate_tokens(model, inputs)",
            setup="from __main__ import generate_tokens",
            globals={"model": model, "inputs": inputs},
            num_threads=torch.get_num_threads(),
            label=label,
            sub_label="no layer skip",
            description="generation",
        ).blocked_autorange()
    )

    # Time early exit at layers 1..15.
    for i in range(1, 16):
        results.append(
            benchmark.Timer(
                stmt="generate_assistant_tokens(model, inputs, assistant_early_exit)",
                setup="from __main__ import generate_assistant_tokens",
                globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
                num_threads=torch.get_num_threads(),
                label=label,
                sub_label=f"layer skip {i}",
                description="generation",
            ).blocked_autorange()
        )

    benchmark.Compare(results).print()
11 changes: 11 additions & 0 deletions examples/research_projects/layer_skip/scripts/config.py
@@ -0,0 +1,11 @@
model_name = "unsloth/Llama-3.2-3B"
tokenizer_name = "unsloth/Llama-3.2-3B"
dataset_name = "WillHeld/top_v2"

output_root_dir = "./checkpoints/"
hub_model_id = f"ariG23498/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
output_dir = f"{output_root_dir}/{hub_model_id}"

per_device_train_batch_size = 8
gradient_accumulation_steps = 1
learning_rate = 2e-5
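
For reference, a quick way to inspect the identifiers derived from the defaults above (run from `scripts/`); the expected values in the comments are just the result of the f-string splits in `config.py`:

```
import config

print(config.hub_model_id)  # ariG23498/layerskip-Llama-3.2-3B-top_v2
print(config.output_dir)    # ./checkpoints//ariG23498/layerskip-Llama-3.2-3B-top_v2
                            # (doubled slash comes from the trailing "/" in output_root_dir; harmless on POSIX paths)
```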
34 changes: 34 additions & 0 deletions examples/research_projects/layer_skip/scripts/custom_trainer.py
@@ -0,0 +1,34 @@
from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
    """SFTTrainer with an auxiliary LayerSkip-style early-exit loss computed at a rotating intermediate layer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_exit_layer = 0  # initialize with 0
        self.always_last_layer = True
        self.early_exit_loss_scale = 1.0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        self.early_exit_layer = (
            self.early_exit_layer % (model.config.num_hidden_layers - 1)
        ) + 1  # rotates between [1, num_hidden_layers-1]
        bs, seqlen = inputs.input_ids.shape

        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)

        hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
        # Intermediate hidden states are pre-norm, so apply the final norm before the LM head.
        if self.early_exit_layer != model.config.num_hidden_layers:
            hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        if self.always_last_layer:
            # Also compute the standard last-layer loss and combine it with the early-exit loss.
            loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
            loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
            # normalize loss scales
            loss = loss / (1.0 + self.early_exit_loss_scale)
        else:
            loss = loss_early

        return loss
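
A standalone sketch of the rotation in `compute_loss` (toy layer count; in training `L` comes from `model.config.num_hidden_layers`): the update `(layer % (L - 1)) + 1` cycles through layers 1..L-1, so the early exit never lands on the final layer, whose loss is already covered by `loss_last` when `always_last_layer` is `True`.

```
# Illustration only -- not part of the trainer. L is a toy layer count.
L = 4
early_exit_layer = 0  # the trainer starts at 0
schedule = []
for _ in range(7):
    early_exit_layer = (early_exit_layer % (L - 1)) + 1
    schedule.append(early_exit_layer)
print(schedule)  # [1, 2, 3, 1, 2, 3, 1] -- cycles over layers 1..L-1
```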
77 changes: 77 additions & 0 deletions examples/research_projects/layer_skip/scripts/layer_skip_sft.py
@@ -0,0 +1,77 @@
import config
from custom_trainer import LayerSkipSFTTrainer

from trl import SFTConfig, DataCollatorForCompletionOnlyLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch


def formatting_prompts_func(example):
    text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"

    # Inject eos_token as a string before tokenization, because they are not always added
    # See: https://github.com/huggingface/transformers/issues/22794 and
    # https://github.com/huggingface/trl/issues/1623
    if tokenizer.eos_token:  # usually something like "</s>" for GPT2 or "<|endoftext|>"
        text += f"{tokenizer.eos_token}"

    return text


if __name__ == "__main__":
    # load the dataset
    print("[INFO] loading the dataset...")
    train_dataset = load_dataset(config.dataset_name, split="train")

    print(f"output_root_dir: {config.output_root_dir}")
    print(f"hub_model_id: {config.hub_model_id}")

    # load the model and tokenizer
    print("[INFO] loading the model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)

    # adding pad and eos tokens if not provided in the tokenizer
    if tokenizer.pad_token is None:
        # Add '[PAD]' token if it doesn't exist
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

    if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
        # Add '[EOS]' token if it doesn't exist
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.eos_token_id = tokenizer.eos_token_id

    # only compute loss on tokens after the response template
    response_template = " ### Response:"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    args = SFTConfig(
        do_train=True,
        bf16=True,
        max_seq_length=None,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        packing=False,
        num_train_epochs=1.0,
        report_to="none",
        push_to_hub=True,
        hub_model_id=config.hub_model_id,
        output_dir=config.output_dir,
        logging_steps=500,
        save_steps=1000,
        save_total_limit=2,
    )

    trainer = LayerSkipSFTTrainer(
        model,
        train_dataset=train_dataset,
        args=args,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )

    trainer.train()