import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline

# *** Set device ***
device = "cuda" if torch.cuda.is_available() else "cpu"
# ***

print(">>> Using device:", device)

# *** Load PPO training log from ppo_logs/ppo_training_log.csv ***
df = pd.read_csv("ppo_logs/ppo_training_log.csv")
# ***

# *** Create the output directory so savefig/open below do not fail ***
os.makedirs("metrics_results", exist_ok=True)
# ***

# *** Load sentiment classifier for evaluation ***
sentiment_pipe = pipeline(
    "text-classification",
    model="wrmurray/roberta-base-finetuned-imdb",
    device=0 if device == "cuda" else -1,
)
# ***
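
# Note: the label string checked below ("NEGATIVE") must match this
# checkpoint's id2label config; if the accuracy comes out as exactly 0 or 1,
# inspect the mapping first:
#   print(sentiment_pipe.model.config.id2label)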

# *** Load GPT-2 model and tokenizer for perplexity evaluation ***
ppl_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
ppl_model.eval()  # inference only; disables dropout
ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
ppl_tokenizer.pad_token = ppl_tokenizer.eos_token
# ***

# *** Define function to compute perplexity for a given text ***
def compute_perplexity(text):
    inputs = ppl_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids.to(device)
    with torch.no_grad():
        loss = ppl_model(input_ids, labels=input_ids).loss
    return torch.exp(loss).item()
# ***
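
# Smoke test (illustrative; the sentence is arbitrary, not from the logs):
# with labels=input_ids, GPT-2 returns the mean token cross-entropy, so
# exp(loss) is the per-token perplexity -- lower means more fluent under GPT-2.
_sanity_ppl = compute_perplexity("This movie was surprisingly good.")
print(f">>> Sanity-check perplexity: {_sanity_ppl:.1f}")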

# *** Define function to compute distinct-n diversity ***
def distinct_n(texts, n):
    total_ngrams = 0
    unique_ngrams = set()
    for t in texts:
        tokens = t.split()
        total_ngrams += max(0, len(tokens) - n + 1)
        for i in range(len(tokens) - n + 1):
            unique_ngrams.add(tuple(tokens[i:i+n]))
    return len(unique_ngrams) / total_ngrams if total_ngrams > 0 else 0
# ***

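# Worked example (toy strings, not log data): distinct_n(["a b a", "a b"], 1)
# sees unigrams {a, b} over 5 total tokens -> 2/5 = 0.4; higher distinct-n
# means less repetition across responses.
print(">>> distinct-1 smoke test:", distinct_n(["a b a", "a b"], 1))
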
# *** Evaluate Sentiment Accuracy using classifier on each response ***
sentiment_labels = [sentiment_pipe(response, truncation=True)[0]["label"] for response in df["response"]]
sentiment_accuracy = np.mean([1 if label == "NEGATIVE" else 0 for label in sentiment_labels])
# ***

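# Batching note (a sketch, not the only option): calling the pipeline once
# per response does one forward pass each. Passing the whole list lets
# transformers batch internally; batch_size=32 is an assumed value:
#   results = sentiment_pipe(list(df["response"]), truncation=True, batch_size=32)
#   sentiment_labels = [r["label"] for r in results]
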
# *** Compute perplexity for each response ***
perplexities = [compute_perplexity(response) for response in df["response"]]
avg_perplexity = np.mean(perplexities)
# ***

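# Methodology note: np.mean over per-response perplexities is not identical
# to corpus-level perplexity, which exponentiates the token-weighted mean
# loss over all responses; either works for tracking trends, but state which
# one is reported.
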
# *** Compute diversity (Distinct-1 and Distinct-2) over all responses ***
dist1 = distinct_n(df["response"], 1)
dist2 = distinct_n(df["response"], 2)
# ***

# *** Plot reward progression ***
plt.figure(figsize=(8, 4))
plt.plot(df["epoch"], df["reward"], marker="o")
plt.title("Reward Progression over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Reward")
plt.grid(True)
plt.tight_layout()
plt.savefig("metrics_results/reward_progression.png")
plt.close()
# ***

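# Optional sketch: if per-epoch rewards are noisy, overlaying a rolling mean
# can make the trend easier to read; window=5 is an assumption, not tuned:
#   smoothed = df["reward"].rolling(window=5, min_periods=1).mean()
#   plt.plot(df["epoch"], smoothed, label="rolling mean (w=5)")
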
# *** Save evaluation metrics summary to a text file ***
with open("metrics_results/eval_metrics_summary.txt", "w") as f:
    f.write("Evaluation Metrics Summary\n")
    f.write("--------------------------\n")
    f.write(f"Sentiment Accuracy: {sentiment_accuracy * 100:.2f}%\n")
    f.write(f"Average Perplexity: {avg_perplexity:.2f}\n")
    f.write(f"Distinct-1: {dist1:.4f}\n")
    f.write(f"Distinct-2: {dist2:.4f}\n")

    # Keep the KL-vs-reward export inside the `with` block: `f` is closed
    # once the block exits, so writing to it afterwards would raise an error.
    if "kl_divergence" in df.columns:
        df[["epoch", "reward", "kl_divergence"]].to_csv("metrics_results/kl_vs_reward.csv", index=False)
        f.write("\nKL vs Reward data saved to metrics_results/kl_vs_reward.csv\n")
# ***

print("Evaluation complete. Metrics saved to 'metrics_results/eval_metrics_summary.txt' and reward progression plotted to 'metrics_results/reward_progression.png'.")