We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent eee9ec9 commit 49d74ceCopy full SHA for 49d74ce
train_grpo.py
@@ -0,0 +1,17 @@
1
+from datasets import load_dataset
2
+from trl import GRPOConfig, GRPOTrainer
3
+
4
+dataset = load_dataset("trl-lib/tldr", split="train")
5
6
+# Define the reward function, which rewards completions that are close to 20 characters
7
+def reward_len(completions, **kwargs):
8
+ return [-abs(20 - len(completion)) for completion in completions]
9
10
+training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
11
+trainer = GRPOTrainer(
12
+ model="Qwen/Qwen2-0.5B-Instruct",
13
+ reward_funcs=reward_len,
14
+ args=training_args,
15
+ train_dataset=dataset,
16
+)
17
+trainer.train()
0 commit comments