Commit 0d4b5d9

Add files via upload
1 parent 972c82d commit 0d4b5d9

File tree

6 files changed: +203, -0 lines changed

Mincheol/eval_metrics.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline

# *** Set device ***
device = "cuda" if torch.cuda.is_available() else "cpu"
# ***

print(">>> Using device:", device)

# Ensure the output directory exists before writing plots and summaries
os.makedirs("metrics_results", exist_ok=True)

# *** Load PPO training log from ppo_logs/ppo_training_log.csv ***
df = pd.read_csv("ppo_logs/ppo_training_log.csv")
# ***

# *** Load sentiment classifier for evaluation ***
sentiment_pipe = pipeline(
    "text-classification",
    model="wrmurray/roberta-base-finetuned-imdb",
    device=0 if device == "cuda" else -1,
)
# ***

# *** Load GPT-2 model and tokenizer for perplexity evaluation ***
ppl_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
ppl_tokenizer.pad_token = ppl_tokenizer.eos_token  # GPT-2 has no pad token
# ***

# *** Define function to compute perplexity for a given text ***
def compute_perplexity(text):
    # Perplexity = exp(mean token-level cross-entropy of GPT-2 on the text)
    inputs = ppl_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids.to(device)
    with torch.no_grad():
        loss = ppl_model(input_ids, labels=input_ids).loss
    return torch.exp(loss).item()
# ***

# *** Define function to compute distinct-n diversity ***
def distinct_n(texts, n):
    # Ratio of unique n-grams to total n-grams across all texts,
    # e.g. "a b a b" has bigrams (a,b), (b,a), (a,b) -> distinct-2 = 2/3.
    total_ngrams = 0
    unique_ngrams = set()
    for t in texts:
        tokens = t.split()
        total_ngrams += max(0, len(tokens) - n + 1)
        for i in range(len(tokens) - n + 1):
            unique_ngrams.add(tuple(tokens[i:i + n]))
    return len(unique_ngrams) / total_ngrams if total_ngrams > 0 else 0
# ***

# *** Evaluate sentiment accuracy using the classifier on each response ***
sentiment_labels = [sentiment_pipe(response)[0]["label"] for response in df["response"]]
sentiment_accuracy = np.mean([1 if label == "NEGATIVE" else 0 for label in sentiment_labels])
# ***

# *** Compute perplexity for each response ***
perplexities = [compute_perplexity(response) for response in df["response"]]
avg_perplexity = np.mean(perplexities)
# ***

# *** Compute diversity (Distinct-1 and Distinct-2) over all responses ***
dist1 = distinct_n(df["response"], 1)
dist2 = distinct_n(df["response"], 2)
# ***

# *** Plot reward progression ***
plt.figure(figsize=(8, 4))
plt.plot(df["epoch"], df["reward"], marker="o")
plt.title("Reward Progression over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Reward")
plt.grid(True)
plt.tight_layout()
plt.savefig("metrics_results/reward_progression.png")
plt.close()
# ***

# *** Save evaluation metrics summary to a text file ***
with open("metrics_results/eval_metrics_summary.txt", "w") as f:
    f.write("Evaluation Metrics Summary\n")
    f.write("--------------------------\n")
    f.write(f"Sentiment Accuracy: {sentiment_accuracy * 100:.2f}%\n")
    f.write(f"Average Perplexity: {avg_perplexity:.2f}\n")
    f.write(f"Distinct-1: {dist1:.4f}\n")
    f.write(f"Distinct-2: {dist2:.4f}\n")

    # Export KL-vs-reward data while the summary file is still open
    # (writing to f after the with-block would fail on a closed file).
    if "kl_divergence" in df.columns:
        df[["epoch", "reward", "kl_divergence"]].to_csv("metrics_results/kl_vs_reward.csv", index=False)
        f.write("\nKL vs Reward data saved to metrics_results/kl_vs_reward.csv\n")
# ***

print("Evaluation complete. Metrics saved to 'metrics_results/eval_metrics_summary.txt' and reward progression plotted to 'metrics_results/reward_progression.png'.")
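A hedged note on inputs: eval_metrics.py assumes ppo_logs/ppo_training_log.csv contains at least the columns epoch, reward, and response, with kl_divergence optional. A minimal sketch for smoke-testing the script without a real PPO run (the file name and placeholder values below are illustrative only, not part of this commit):

# make_dummy_log.py -- hypothetical helper, not part of this commit.
# Writes a tiny ppo_logs/ppo_training_log.csv with the columns eval_metrics.py reads.
import os
import pandas as pd

os.makedirs("ppo_logs", exist_ok=True)
pd.DataFrame({
    "epoch": [0, 1],
    "reward": [0.10, 0.25],  # placeholder rewards
    "response": ["this movie was dull", "a tedious, joyless film"],  # placeholder text
}).to_csv("ppo_logs/ppo_training_log.csv", index=False)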

Mincheol/eval_metrics.slurm

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
#!/bin/bash
#SBATCH --job-name=eval_metrics
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --time=00:10:00
#SBATCH --mem=8G
#SBATCH --output=metrics_results/eval_metrics.out
#SBATCH --error=metrics_results/eval_metrics.err

ml GCCcore/13.3.0
ml Miniconda3/23.10.0-1
source ~/.bashrc
conda activate grpo

cd /scratch/user/mincheolseong/GRPO_project/ECEN743-GRPO-Project-Proposal/mincheol_runs
python eval_metrics.py
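A hedged operational note: Slurm does not create the directories named in --output/--error, so metrics_results/ presumably has to exist (mkdir -p metrics_results) before submitting with sbatch eval_metrics.slurm; otherwise the job's log files cannot be written.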
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
/scratch/user/mincheolseong/.conda/envs/grpo/lib/python3.10/site-packages/huggingface_hub/file_download.py:896: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/scratch/user/mincheolseong/.conda/envs/grpo/lib/python3.10/site-packages/huggingface_hub/file_download.py:896: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(

Mincheol/preload_reward_model.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# preload_reward_model.py

from transformers import pipeline

print("Downloading RoBERTa IMDb classifier to cache...")
pipe = pipeline("text-classification", model="wrmurray/roberta-base-finetuned-imdb")
print("Done. Model is now cached.")
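A hedged usage sketch: presumably this preload runs once on a login node with internet access, so the GPU jobs can load the classifier from the local cache. On an offline compute node, one way to confirm the model loads purely from cache is to disable Hub access (HF_HUB_OFFLINE is a standard huggingface_hub switch); the snippet below is illustrative, not part of this commit:

# check_cache.py -- hypothetical, not part of this commit.
import os
os.environ["HF_HUB_OFFLINE"] = "1"  # must be set before transformers is imported

from transformers import pipeline

# Raises an error if the model/tokenizer files are not already in the local cache.
pipe = pipeline("text-classification", model="wrmurray/roberta-base-finetuned-imdb")
print(pipe("This movie was terrible.")[0])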

Mincheol/preprocess_imdb.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# preprocess_imdb.py

from datasets import load_dataset
from transformers import GPT2Tokenizer
import os

# Save directory
SAVE_PATH = "tokenized_imdb_negative"
os.makedirs(SAVE_PATH, exist_ok=True)

def main():
    # 1. Load the IMDb dataset
    print("▶ Loading IMDb dataset...")
    dataset = load_dataset("imdb")

    # 2. Filter to negative reviews only
    print("▶ Filtering negative reviews...")
    negative_reviews = dataset["train"].filter(lambda x: x["label"] == 0)

    # 3. Convert to prompt-completion format
    def make_prompt_completion(example):
        prompt = "Generate a negative movie review:\n"
        completion = example["text"]
        return {
            "prompt": prompt,
            "completion": completion,
        }

    formatted = negative_reviews.map(make_prompt_completion)

    # 4. Load tokenizer
    print("▶ Loading tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token

    # 5. Tokenize prompt and completion separately, then concatenate
    def tokenize(example):
        prompt_ids = tokenizer.encode(example["prompt"], truncation=True, max_length=64)
        completion_ids = tokenizer.encode(example["completion"], truncation=True, max_length=128)
        input_ids = prompt_ids + completion_ids
        attention_mask = [1] * len(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    print("▶ Tokenizing...")
    tokenized = formatted.map(tokenize, remove_columns=["text", "label", "prompt", "completion"])

    # 6. Save to disk
    print(f"▶ Saving to: {SAVE_PATH}")
    tokenized.save_to_disk(SAVE_PATH)
    print("▶ Done.")

if __name__ == "__main__":
    main()
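For downstream training code, the dataset saved above can be reloaded with datasets.load_from_disk; a minimal inspection sketch (not part of this commit):

# inspect_dataset.py -- hypothetical, not part of this commit.
from datasets import load_from_disk
from transformers import GPT2Tokenizer

ds = load_from_disk("tokenized_imdb_negative")
tok = GPT2Tokenizer.from_pretrained("gpt2")

print(ds)  # expect features: input_ids, attention_mask
print(tok.decode(ds[0]["input_ids"][:32]))  # prompt tokens followed by review text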

Mincheol/run_ppo.slurm

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
#!/bin/bash
#SBATCH --job-name=ppo_gpt2
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --time=02:00:00
#SBATCH --mem=16G
#SBATCH --output=ppo_logs/ppo_run.out
#SBATCH --error=ppo_logs/ppo_run.err

ml GCCcore/13.3.0
ml Miniconda3/23.10.0-1
source ~/.bashrc
conda activate grpo

cd /scratch/user/mincheolseong/GRPO_project/ECEN743-GRPO-Project-Proposal/mincheol_runs
python run_ppo.py
