Commit 6faaa1b

Add files via upload
1 parent 33b55ce commit 6faaa1b

File tree

1 file changed: +218 −44 lines changed

Mincheol/run_ppo.py

Lines changed: 218 additions & 44 deletions
@@ -1,54 +1,90 @@
 # run_ppo.py
-
+print(">>> Start script...ff", flush=True)
 import warnings
+import re
+import random
+print(">>> Imported warnings", flush=True)
+import torch
+print(">>> Imported torch", flush=True)
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
+print(">>> Imported transformers", flush=True)
 from datasets import load_from_disk
+print(">>> Imported datasets", flush=True)
 from trl import PPOTrainer, PPOConfig
+print(">>> Imported PPOTrainer", flush=True)
 from trl import AutoModelForCausalLMWithValueHead
+print(">>> Imported ValueHead", flush=True)
 from trl.core import LengthSampler
 from transformers import pipeline
-import torch
 import csv

 warnings.filterwarnings("ignore", message="`resume_download` is deprecated")
 warnings.filterwarnings("ignore", message="Xformers is not installed correctly")
 warnings.filterwarnings("ignore", message="No dataset is provided.")

+n_epochs = 20
+n_samples = 1
+
 # Set device
 device = "cuda" if torch.cuda.is_available() else "cpu"

 # Load tokenizer and model
+print(">>> Loading tokenizer and model...", flush=True)
 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 tokenizer.pad_token = tokenizer.eos_token
 model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2").to(device)

+
 # Load preprocessed IMDb data (negative reviews only)
 dataset = load_from_disk("tokenized_imdb_negative")

 # Sample a few prompts for training
-#prompts = [tokenizer.decode(example["input_ids"][:64]) for example in dataset.select(range(64))]
-prompts = ["Generate a negative movie review:\n" + tokenizer.decode(example["input_ids"][:64]) # 12,500
-           for example in dataset.select(range(50))] # 50 for minimal experience
+prompt_templates = [
+    "This movie is so bad I had to leave. Continue the review:\n",
+    "This film is a waste of time. Finish this:\n",
+    "I hated everything about this movie. Explain why:\n",
+    "The worst film ever. Expand the comment:\n",
+    "The story was painfully boring. Go on:\n",
+    "The direction and acting were terrible. Elaborate:\n",
+]
+
+#prompts = [random.choice(prompt_templates) + "\n" + tokenizer.decode(example["input_ids"][:64]) # the number of tokens
+#           for example in dataset.select(range(50))] # 50 for minimal experience, up to 12,500

-print("prompts", prompts)
+prompts = [
+    random.choice(prompt_templates)
+    for _ in range(50)
+]
+
+#print("prompts", prompts)

 # Load reward model (IMDb classifier)
-reward_pipe = pipeline(
-    "text-classification",
-    model="wrmurray/roberta-base-finetuned-imdb",
-    device=0 if device == "cuda" else -1
-)
+print(">>> Loading reward model...", flush=True)
+#sentiment_pipe = pipeline(
+#    "text-classification",
+#    model="textattack/roberta-base-imdb",
+#    device=0 if torch.cuda.is_available() else -1
+#)
+#
+#toxicity_pipe = pipeline(
+#    "text-classification",
+#    model="unitary/toxic-bert",
+#    device=0 if torch.cuda.is_available() else -1
+#)
+print(">>> Starting training...", flush=True)

 # PPO config
 ppo_config = PPOConfig(
     model_name="gpt2",
-    learning_rate=1.41e-5,
-    batch_size=1,
-    mini_batch_size=1,
+    learning_rate=1e-5,
+    batch_size=8,
+    mini_batch_size=2,
     ppo_epochs=4,
     log_with="tensorboard",
     kl_penalty="kl",
-    target_kl=6.0
+    target_kl=0.2,
+    ratio_threshold=20.0,
+    early_stopping=True
 )

 ppo_trainer = PPOTrainer(
@@ -59,40 +95,178 @@

 log_file = open("ppo_logs/ppo_training_log.csv", "w", newline='')
 csv_writer = csv.writer(log_file)
-csv_writer.writerow(["epoch", "reward", "kl_divergence", "response"])
+csv_writer.writerow(["epoch", "step", "reward", "kl_divergence", "GPT-2 response", "PPO response"])
+
+neg_patterns = [
+    r"\bbad\b", r"\bterrible\b", r"\bawful\b", r"\bhorrible\b", r"\bpoor\b", r"\bboring\b", r"\bslow\b", r"\bdull\b",
+    r"\bdisappointing\b", r"\bannoying\b", r"waste of time", r"not worth it", r"\bunbearable\b", r"\bmediocre\b",
+    r"\bforgettable\b", r"\bflawed\b", r"\bunwatchable\b", r"\bgarbage\b", r"\btrash\b", r"\bmess\b", r"\bcheesy\b",
+    r"\bcringe\b", r"\bregret\b", r"\bpathetic\b", r"\bsucks\b", r"\bstupid\b", r"\bnonsense\b", r"makes no sense",
+    r"didn't like", r"couldn't finish", r"hated it", r"\bconfusing\b", r"\bpredictable\b", r"\bbuggy\b",
+    r"\bridiculous\b", r"\babsurd\b", r"\boverrated\b", r"\bunderrated\b", r"\bincoherent\b", r"\bpainful\b",
+    r"\bfake\b", r"\bpointless\b", r"\brepetitive\b", r"\bshallow\b", r"\bcliched\b", r"\blame\b", r"\blazy\b",
+    r"\bbroken\b", r"poorly made", r"script was bad", r"bad acting", r"bad writing", r"plot holes",
+    r"no plot", r"no development", r"no character arc", r"too long", r"dragged", r"drawn out",
+    r"overacted", r"underacted", r"low budget", r"\bcheap\b", r"low quality", r"poor direction",
+    r"\binconsistent\b", r"\bunbelievable\b", r"\bforced\b", r"bad pacing", r"terrible ending",
+    r"no logic", r"makes you sleep", r"predictable twists", r"hate the ending", r"poor performance",
+    r"fails to deliver", r"didn't work", r"had issues", r"not engaging", r"hard to watch", r"not funny",
+    r"not scary", r"not interesting", r"annoying characters", r"\boverdone\b", r"\bpretentious\b",
+    r"\bwannabe\b", r"\boveredited\b", r"\bunderwhelming\b", r"\bdisconnected\b", r"badly shot"
+]
+
+
+def repetition_reward(text: str) -> float:
+    words = text.lower().split()
+    if not words:
+        return 0.0
+    return len(set(words)) / len(words) # distinct-1
+
+
+def length_reward(text: str, min_len=10, max_len=64) -> float:
+    length = len(text.split())
+    return min(1.0, max(0.0, (length - min_len) / (max_len - min_len)))
+
+
+def negativity_reward(text: str) -> float:
+    text_lower = text.lower()
+
+    match_count = sum(bool(re.search(pat, text_lower)) for pat in neg_patterns)
+    return min(1.0, match_count / 2.0)
+
+
+def combo_reward(text: str, w_rep=0.05, w_len=0.05, w_neg=0.9) -> float:
+    r1 = repetition_reward(text)
+    r2 = length_reward(text)
+    r3 = negativity_reward(text)
+    total = w_rep * r1 + w_len * r2 + w_neg * r3
+    return min(1.0, total)
+
+all_queries = []
+all_responses = []
+all_rewards = []

 # Training loop
-for epoch, prompt in enumerate(prompts): # epoch -> step: naming issue
-    # Encode prompt
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
-    # Generate response
-    generation_output = model.generate(
-        input_ids,
-        max_new_tokens=64,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    response = tokenizer.decode(generation_output[0][input_ids.shape[-1]:], skip_special_tokens=True)
-
-    # Compute reward
-    reward_output = reward_pipe(response)
-    reward_score = reward_output[0]["score"]
-    reward_tensor = torch.tensor(reward_score).to(device)
-    rewards = [reward_tensor]
-
-    # PPO step
-    query_tensor = tokenizer(prompt, return_tensors="pt").input_ids[0].to(device)
-    response_tensor = tokenizer(response, return_tensors="pt").input_ids[0].to(device)
-    ppo_trainer.step([query_tensor], [response_tensor], rewards)
-
-    train_stats = ppo_trainer.step([query_tensor], [response_tensor], rewards)
+for epoch in range(n_epochs):
+    print(f"\n=== Epoch {epoch+1}/{n_epochs} ===")

-    kl_value = train_stats.get("kl", train_stats.get("objective/kl", None))
+    for step, prompt in enumerate(prompts): # epoch -> step: naming issue
+        rewards = []
+        responses = []
+
+        for sample_idx in range(n_samples): # ***
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device).long()
+
+            with torch.no_grad():
+                baseline_output = model.pretrained_model.generate( # ***
+                    input_ids,
+                    max_new_tokens=80,
+                    min_new_tokens=20,
+                    pad_token_id=tokenizer.eos_token_id,
+                    repetition_penalty=1.2
+                )
+            gpt2_response = tokenizer.decode( # ***
+                baseline_output[0][input_ids.shape[-1]:],
+                skip_special_tokens=True
+            )
+
+            generation_output = model.generate(
+                input_ids,
+                max_new_tokens=80,
+                min_new_tokens=10,
+                pad_token_id=tokenizer.eos_token_id,
+                repetition_penalty=1.2,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                eos_token_id=None
+            )
+
+            if generation_output.shape[-1] <= input_ids.shape[-1]:
+                print("Empty generation. Skipping.")
+                continue
+
+            response = tokenizer.decode(generation_output[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+            if len(response.strip()) == 0:
+                print("Empty string response. Skipping.")
+                continue
+
+            if "<a href=" in response or "http" in response:
+                print("URL-like response. Skipping.")
+                continue
+
+            responses.append(response)
+
+            #reward_output = reward_pipe(response)
+            #reward_score = reward_output[0]["score"]
+            reward_score = combo_reward(response)
+            rewards.append(torch.tensor(reward_score).to(device))
+
+            # log single sample
+            print(f"[Epoch {epoch+1} | Step {step+1}/{len(prompts)} | "
+                  f"Reward: {reward_score:.4f} | Response: {response[:80]}...", flush=True) #
+
+        if len(rewards) == 0:
+            print(f"[Epoch {epoch+1} | Step {step+1}] Skipped: No valid response.")
+            continue # skip this step

-    csv_writer.writerow([epoch + 1, reward_score, kl_value, response])
+        avg_reward = torch.mean(torch.stack(rewards)) #
+        best_response = responses[rewards.index(max(rewards))] #
+
+
+        query_tensor = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+        query_tensor = {k: v.to(device) for k, v in query_tensor.items()}
+        query_tensor["input_ids"] = query_tensor["input_ids"].long()
+
+        response_tensor = tokenizer(best_response, return_tensors="pt", padding=True, truncation=True)
+        response_tensor = {k: v.to(device) for k, v in response_tensor.items()}
+        response_tensor["input_ids"] = response_tensor["input_ids"].long()
+
+        all_queries.append(query_tensor["input_ids"].squeeze(0))
+        all_responses.append(response_tensor["input_ids"].squeeze(0))
+        all_rewards.append(avg_reward)
+
+
+        print(f"Prompt[:100]: {prompt[:100]}") #
+        print(f"Response[:100]: {response[:100]}") #
+
+        #decoded_query = tokenizer.decode(query_tensor.tolist(), skip_special_tokens=True)
+        #decoded_response = tokenizer.decode(response_tensor.tolist(), skip_special_tokens=True)
+        decoded_query = tokenizer.decode(query_tensor["input_ids"].squeeze(0).tolist(), skip_special_tokens=True)
+        decoded_response = tokenizer.decode(response_tensor["input_ids"].squeeze(0).tolist(), skip_special_tokens=True)
+
+        print(f"Decoded query[:100]: {decoded_query[:100]}")
+        print(f"Decoded PPO target response[:100]: {decoded_response[:100]}")
+
+        r1 = repetition_reward(response)
+        r2 = length_reward(response)
+        r3 = negativity_reward(response)
+        combo = combo_reward(response)
+        print(f"Reward components -> repetition: {r1:.2f}, length: {r2:.2f}, negativity: {r3:.2f}, combo: {combo:.2f}")
+
+        if len(all_queries) == ppo_config.batch_size:
+            train_stats = ppo_trainer.step(all_queries, all_responses, all_rewards)
+            all_queries, all_responses, all_rewards = [], [], []
+
+        csv_writer.writerow([
+            epoch + 1, step + 1,
+            avg_reward.item(),
+            train_stats.get("kl", train_stats.get("objective/kl", None)),
+            gpt2_response,
+            best_response
+        ])
+        #train_stats = ppo_trainer.step([query_tensor], [response_tensor], [avg_reward]) #
+        #kl_value = train_stats.get("kl", train_stats.get("objective/kl", None))
+        #kl_value = train_stats.get("kl", train_stats.get("objective/kl", None)) if 'train_stats' in locals() else None
+
+
+
+        #csv_writer.writerow([epoch + 1, step + 1, avg_reward.item(), kl_value, gpt2_response, best_response]) #
+        #csv_writer.writerow([epoch + 1, step + 1, avg_reward.item(), kl_value, gpt2_response, best_response])
+if all_queries:
+    train_stats = ppo_trainer.step(all_queries, all_responses, all_rewards)

-    # Log progress
-    print(f"[{epoch+1}/{len(prompts)}] Reward: {reward_score:.4f} | Response: {response[:80]}...", flush=True)

 print("Training complete.")

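Note on the reward change in this commit: the external IMDb classifier (`reward_pipe`) is replaced by a hand-crafted `combo_reward` that mixes distinct-1 repetition, length, and keyword-based negativity with weights 0.05 / 0.05 / 0.9. The sketch below is illustrative only: it re-implements a trimmed copy of those functions with just three of the patterns and scores a made-up review, so it can be run outside the training script without torch or trl installed.

```python
# Illustrative only: trimmed re-implementation of the reward shaping added in this commit.
import re

neg_patterns = [r"\bterrible\b", r"\bboring\b", r"waste of time"]  # small subset for the demo

def repetition_reward(text: str) -> float:
    # distinct-1: fraction of unique words in the response
    words = text.lower().split()
    return len(set(words)) / len(words) if words else 0.0

def length_reward(text: str, min_len: int = 10, max_len: int = 64) -> float:
    # 0.0 at <=10 words, ramping to 1.0 at >=64 words
    length = len(text.split())
    return min(1.0, max(0.0, (length - min_len) / (max_len - min_len)))

def negativity_reward(text: str) -> float:
    # saturates once two negative patterns match
    hits = sum(bool(re.search(p, text.lower())) for p in neg_patterns)
    return min(1.0, hits / 2.0)

def combo_reward(text: str, w_rep=0.05, w_len=0.05, w_neg=0.9) -> float:
    return min(1.0, w_rep * repetition_reward(text)
                    + w_len * length_reward(text)
                    + w_neg * negativity_reward(text))

review = "A terrible, boring film and a complete waste of time overall."
print(round(combo_reward(review), 3))  # negativity term dominates because w_neg=0.9
```

Because `w_neg` is 0.9, a response that trips two of the negative patterns already lands close to the 1.0 cap regardless of its length or repetition score.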
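The script now logs one row per step to `ppo_logs/ppo_training_log.csv` with the header added in this commit (`epoch, step, reward, kl_divergence, GPT-2 response, PPO response`). A minimal way to inspect reward and KL over training, assuming pandas is available (not part of this commit):

```python
# Quick look at the training log written by run_ppo.py (assumes pandas is installed).
import pandas as pd

df = pd.read_csv("ppo_logs/ppo_training_log.csv")
print(df[["epoch", "step", "reward", "kl_divergence"]].describe())
print(df.groupby("epoch")["reward"].mean())  # average shaped reward per epoch
```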