# run_ppo.py
-
+print(">>> Start script...", flush=True)
import warnings
+import re
+import random
+print(">>> Imported warnings", flush=True)
+import torch
+print(">>> Imported torch", flush=True)
from transformers import GPT2LMHeadModel, GPT2Tokenizer
+print(">>> Imported transformers", flush=True)
from datasets import load_from_disk
+print(">>> Imported datasets", flush=True)
from trl import PPOTrainer, PPOConfig
+print(">>> Imported PPOTrainer", flush=True)
from trl import AutoModelForCausalLMWithValueHead
+print(">>> Imported ValueHead", flush=True)
from trl.core import LengthSampler
from transformers import pipeline
-import torch
import csv

warnings.filterwarnings("ignore", message="`resume_download` is deprecated")
warnings.filterwarnings("ignore", message="Xformers is not installed correctly")
warnings.filterwarnings("ignore", message="No dataset is provided.")

+n_epochs = 20
+n_samples = 1
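+# n_samples is the number of candidate generations drawn per prompt at each step;
+# only the best-scoring candidate is kept for the PPO update (see the training loop below).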
+
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model
+print(">>> Loading tokenizer and model...", flush=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2").to(device)
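+# AutoModelForCausalLMWithValueHead wraps the GPT-2 LM with an extra scalar value head,
+# which PPO uses as its critic when estimating advantages.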

+
# Load preprocessed IMDb data (negative reviews only)
dataset = load_from_disk("tokenized_imdb_negative")

# Sample a few prompts for training
-#prompts = [tokenizer.decode(example["input_ids"][:64]) for example in dataset.select(range(64))]
-prompts = ["Generate a negative movie review:\n" + tokenizer.decode(example["input_ids"][:64])  # 12,500
-           for example in dataset.select(range(50))]  # 50 for minimal experience
+prompt_templates = [
+    "This movie is so bad I had to leave. Continue the review:\n",
+    "This film is a waste of time. Finish this:\n",
+    "I hated everything about this movie. Explain why:\n",
+    "The worst film ever. Expand the comment:\n",
+    "The story was painfully boring. Go on:\n",
+    "The direction and acting were terrible. Elaborate:\n",
+]
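+# Prompts are now built only from the fixed negative-review templates above; the tokenized
+# IMDb dataset loaded earlier is kept for reference but is no longer sampled from.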
+
+#prompts = [random.choice(prompt_templates) + "\n" + tokenizer.decode(example["input_ids"][:64])  # the number of tokens
+#           for example in dataset.select(range(50))]  # 50 for minimal experience, up to 12,500

-print("prompts", prompts)
+prompts = [
+    random.choice(prompt_templates)
+    for _ in range(50)
+]
+
+#print("prompts", prompts)

# Load reward model (IMDb classifier)
-reward_pipe = pipeline(
-    "text-classification",
-    model="wrmurray/roberta-base-finetuned-imdb",
-    device=0 if device == "cuda" else -1
-)
+print(">>> Loading reward model...", flush=True)
+#sentiment_pipe = pipeline(
+#    "text-classification",
+#    model="textattack/roberta-base-imdb",
+#    device=0 if torch.cuda.is_available() else -1
+#)
+#
+#toxicity_pipe = pipeline(
+#    "text-classification",
+#    model="unitary/toxic-bert",
+#    device=0 if torch.cuda.is_available() else -1
+#)
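+# NOTE: both classifier-based reward pipelines above are left disabled in this version;
+# the reward is computed by the rule-based combo_reward() defined further below.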
+print(">>> Starting training...", flush=True)

# PPO config
ppo_config = PPOConfig(
    model_name="gpt2",
-    learning_rate=1.41e-5,
-    batch_size=1,
-    mini_batch_size=1,
+    learning_rate=1e-5,
+    batch_size=8,
+    mini_batch_size=2,
    ppo_epochs=4,
    log_with="tensorboard",
    kl_penalty="kl",
-    target_kl=6.0
+    target_kl=0.2,
+    ratio_threshold=20.0,
+    early_stopping=True
)
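+# batch_size=8 examples are accumulated before each PPO update and optimised in
+# mini-batches of 2 over ppo_epochs=4 passes; target_kl, ratio_threshold and
+# early_stopping act as guards against overly large policy updates.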

ppo_trainer = PPOTrainer(

log_file = open("ppo_logs/ppo_training_log.csv", "w", newline='')
csv_writer = csv.writer(log_file)
-csv_writer.writerow(["epoch", "reward", "kl_divergence", "response"])
+csv_writer.writerow(["epoch", "step", "reward", "kl_divergence", "GPT-2 response", "PPO response"])
+
+neg_patterns = [
+    r"\bbad\b", r"\bterrible\b", r"\bawful\b", r"\bhorrible\b", r"\bpoor\b", r"\bboring\b", r"\bslow\b", r"\bdull\b",
+    r"\bdisappointing\b", r"\bannoying\b", r"waste of time", r"not worth it", r"\bunbearable\b", r"\bmediocre\b",
+    r"\bforgettable\b", r"\bflawed\b", r"\bunwatchable\b", r"\bgarbage\b", r"\btrash\b", r"\bmess\b", r"\bcheesy\b",
+    r"\bcringe\b", r"\bregret\b", r"\bpathetic\b", r"\bsucks\b", r"\bstupid\b", r"\bnonsense\b", r"makes no sense",
+    r"didn't like", r"couldn't finish", r"hated it", r"\bconfusing\b", r"\bpredictable\b", r"\bbuggy\b",
+    r"\bridiculous\b", r"\babsurd\b", r"\boverrated\b", r"\bunderrated\b", r"\bincoherent\b", r"\bpainful\b",
+    r"\bfake\b", r"\bpointless\b", r"\brepetitive\b", r"\bshallow\b", r"\bcliched\b", r"\blame\b", r"\blazy\b",
+    r"\bbroken\b", r"poorly made", r"script was bad", r"bad acting", r"bad writing", r"plot holes",
+    r"no plot", r"no development", r"no character arc", r"too long", r"dragged", r"drawn out",
+    r"overacted", r"underacted", r"low budget", r"\bcheap\b", r"low quality", r"poor direction",
+    r"\binconsistent\b", r"\bunbelievable\b", r"\bforced\b", r"bad pacing", r"terrible ending",
+    r"no logic", r"makes you sleep", r"predictable twists", r"hate the ending", r"poor performance",
+    r"fails to deliver", r"didn't work", r"had issues", r"not engaging", r"hard to watch", r"not funny",
+    r"not scary", r"not interesting", r"annoying characters", r"\boverdone\b", r"\bpretentious\b",
+    r"\bwannabe\b", r"\boveredited\b", r"\bunderwhelming\b", r"\bdisconnected\b", r"badly shot"
+]
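+# Each pattern above counts at most once per response; negativity_reward() below caps
+# the score at 1.0 as soon as two or more patterns match.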
+
+
+def repetition_reward(text: str) -> float:
+    words = text.lower().split()
+    if not words:
+        return 0.0
+    return len(set(words)) / len(words)  # distinct-1: share of unique words
+
+
+def length_reward(text: str, min_len=10, max_len=64) -> float:
+    # scales linearly from 0 at min_len words up to 1 at max_len words
+    length = len(text.split())
+    return min(1.0, max(0.0, (length - min_len) / (max_len - min_len)))
+
+
+def negativity_reward(text: str) -> float:
+    # counts how many negative-review patterns occur at least once
+    text_lower = text.lower()
+    match_count = sum(bool(re.search(pat, text_lower)) for pat in neg_patterns)
+    return min(1.0, match_count / 2.0)
+
+
+def combo_reward(text: str, w_rep=0.05, w_len=0.05, w_neg=0.9) -> float:
+    # weighted mix of the three heuristics, clipped to 1.0
+    r1 = repetition_reward(text)
+    r2 = length_reward(text)
+    r3 = negativity_reward(text)
+    total = w_rep * r1 + w_len * r2 + w_neg * r3
+    return min(1.0, total)
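+# With w_neg=0.9 the negativity term dominates the reward; all three components lie in
+# [0, 1], so the combined reward is also bounded to [0, 1].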
+
+all_queries = []
+all_responses = []
+all_rewards = []
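+# These buffers accumulate one (query, response, reward) triple per prompt until
+# ppo_config.batch_size of them are available for a single ppo_trainer.step() call.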

# Training loop
-for epoch, prompt in enumerate(prompts):  # epoch -> step: naming issue
-    # Encode prompt
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
-    # Generate response
-    generation_output = model.generate(
-        input_ids,
-        max_new_tokens=64,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    response = tokenizer.decode(generation_output[0][input_ids.shape[-1]:], skip_special_tokens=True)
-
-    # Compute reward
-    reward_output = reward_pipe(response)
-    reward_score = reward_output[0]["score"]
-    reward_tensor = torch.tensor(reward_score).to(device)
-    rewards = [reward_tensor]
-
-    # PPO step
-    query_tensor = tokenizer(prompt, return_tensors="pt").input_ids[0].to(device)
-    response_tensor = tokenizer(response, return_tensors="pt").input_ids[0].to(device)
-    ppo_trainer.step([query_tensor], [response_tensor], rewards)
-
-    train_stats = ppo_trainer.step([query_tensor], [response_tensor], rewards)
+for epoch in range(n_epochs):
+    print(f"\n=== Epoch {epoch+1}/{n_epochs} ===")

-    kl_value = train_stats.get("kl", train_stats.get("objective/kl", None))
+    for step, prompt in enumerate(prompts):  # iterate over the fixed prompt list each epoch
+        rewards = []
+        responses = []
+
+        for sample_idx in range(n_samples):  # draw n_samples candidate generations per prompt
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device).long()
+
+            with torch.no_grad():
+                baseline_output = model.pretrained_model.generate(  # greedy decode, logged as the "GPT-2 response" baseline
+                    input_ids,
+                    max_new_tokens=80,
+                    min_new_tokens=20,
+                    pad_token_id=tokenizer.eos_token_id,
+                    repetition_penalty=1.2
+                )
+                gpt2_response = tokenizer.decode(  # keep only the newly generated tokens
+                    baseline_output[0][input_ids.shape[-1]:],
+                    skip_special_tokens=True
+                )
+
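+            # Sampled generation from the PPO policy (top-k/top-p); eos_token_id=None
+            # disables EOS-based stopping, so generation runs up to max_new_tokens.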
+            generation_output = model.generate(
+                input_ids,
+                max_new_tokens=80,
+                min_new_tokens=10,
+                pad_token_id=tokenizer.eos_token_id,
+                repetition_penalty=1.2,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                eos_token_id=None
+            )
+
+            if generation_output.shape[-1] <= input_ids.shape[-1]:
+                print("Empty generation. Skipping.")
+                continue
+
+            response = tokenizer.decode(generation_output[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+            if len(response.strip()) == 0:
+                print("Empty string response. Skipping.")
+                continue
+
+            if "<a href=" in response or "http" in response:
+                print("URL-like response. Skipping.")
+                continue
+
+            responses.append(response)
+
+            #reward_output = reward_pipe(response)
+            #reward_score = reward_output[0]["score"]
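+            # The rule-based combo_reward replaces the disabled classifier reward above.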
+            reward_score = combo_reward(response)
+            rewards.append(torch.tensor(reward_score).to(device))
+
+            # log single sample
+            print(f"[Epoch {epoch+1} | Step {step+1}/{len(prompts)}] "
+                  f"Reward: {reward_score:.4f} | Response: {response[:80]}...", flush=True)
+
+        if len(rewards) == 0:
+            print(f"[Epoch {epoch+1} | Step {step+1}] Skipped: No valid response.")
+            continue  # skip this step

-    csv_writer.writerow([epoch + 1, reward_score, kl_value, response])
+        avg_reward = torch.mean(torch.stack(rewards))
+        best_response = responses[rewards.index(max(rewards))]
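+        # The PPO update pairs the *mean* reward over the sampled candidates with the
+        # *best*-scoring response; with n_samples = 1 the two coincide.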
+
+        query_tensor = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+        query_tensor = {k: v.to(device) for k, v in query_tensor.items()}
+        query_tensor["input_ids"] = query_tensor["input_ids"].long()
+
+        response_tensor = tokenizer(best_response, return_tensors="pt", padding=True, truncation=True)
+        response_tensor = {k: v.to(device) for k, v in response_tensor.items()}
+        response_tensor["input_ids"] = response_tensor["input_ids"].long()
+
+        all_queries.append(query_tensor["input_ids"].squeeze(0))
+        all_responses.append(response_tensor["input_ids"].squeeze(0))
+        all_rewards.append(avg_reward)
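+        # best_response is re-tokenized from decoded text (rather than reusing the generated
+        # token ids), and only its input_ids are handed to the PPO step.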
+
+        print(f"Prompt[:100]: {prompt[:100]}")
+        print(f"Response[:100]: {response[:100]}")
+
+        #decoded_query = tokenizer.decode(query_tensor.tolist(), skip_special_tokens=True)
+        #decoded_response = tokenizer.decode(response_tensor.tolist(), skip_special_tokens=True)
+        decoded_query = tokenizer.decode(query_tensor["input_ids"].squeeze(0).tolist(), skip_special_tokens=True)
+        decoded_response = tokenizer.decode(response_tensor["input_ids"].squeeze(0).tolist(), skip_special_tokens=True)
+
+        print(f"Decoded query[:100]: {decoded_query[:100]}")
+        print(f"Decoded PPO target response[:100]: {decoded_response[:100]}")
+
+        r1 = repetition_reward(response)
+        r2 = length_reward(response)
+        r3 = negativity_reward(response)
+        combo = combo_reward(response)
+        print(f"Reward components -> repetition: {r1:.2f}, length: {r2:.2f}, negativity: {r3:.2f}, combo: {combo:.2f}")
+
+        if len(all_queries) == ppo_config.batch_size:
+            train_stats = ppo_trainer.step(all_queries, all_responses, all_rewards)
+            all_queries, all_responses, all_rewards = [], [], []
+
+            csv_writer.writerow([
+                epoch + 1, step + 1,
+                avg_reward.item(),
+                train_stats.get("kl", train_stats.get("objective/kl", None)),
+                gpt2_response,
+                best_response
+            ])
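+            # A CSV row is written only when a PPO update actually runs, i.e. once per
+            # accumulated batch of ppo_config.batch_size prompts.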
+        #train_stats = ppo_trainer.step([query_tensor], [response_tensor], [avg_reward])
+        #kl_value = train_stats.get("kl", train_stats.get("objective/kl", None)) if 'train_stats' in locals() else None
+        #csv_writer.writerow([epoch + 1, step + 1, avg_reward.item(), kl_value, gpt2_response, best_response])
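+
+# NOTE: some TRL versions require exactly ppo_config.batch_size samples per ppo_trainer.step(),
+# so this final partial batch may be rejected; dropping the remainder is a simple fallback.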
+if all_queries:
+    train_stats = ppo_trainer.step(all_queries, all_responses, all_rewards)

-    # Log progress
-    print(f"[{epoch+1}/{len(prompts)}] Reward: {reward_score:.4f} | Response: {response[:80]}...", flush=True)

print("Training complete.")
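+# NOTE: log_file is not closed explicitly, so any buffered CSV rows are flushed only at interpreter exit.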