Skip to content

Commit 49d74ce

Browse files
committed
initial commit
1 parent eee9ec9 commit 49d74ce

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

train_grpo.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from datasets import load_dataset
2+
from trl import GRPOConfig, GRPOTrainer
3+
4+
dataset = load_dataset("trl-lib/tldr", split="train")
5+
6+
# Define the reward function, which rewards completions that are close to 20 characters
7+
def reward_len(completions, **kwargs):
8+
return [-abs(20 - len(completion)) for completion in completions]
9+
10+
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
11+
trainer = GRPOTrainer(
12+
model="Qwen/Qwen2-0.5B-Instruct",
13+
reward_funcs=reward_len,
14+
args=training_args,
15+
train_dataset=dataset,
16+
)
17+
trainer.train()

0 commit comments

Comments
 (0)