-
Notifications
You must be signed in to change notification settings - Fork 88
[skyrl-gym] GSM8k - LLM Judge example #74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 14 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
d7f774b
merge
erictang000 0c6d467
x
lynnliu030 3c6cc4c
x
lynnliu030 5c310cc
x
lynnliu030 1509ff5
x
lynnliu030 ac91847
x
lynnliu030 fb3a8bf
Merge remote-tracking branch 'origin/main' into shu/merge-llm-judge
lynnliu030 2fa62c2
x
lynnliu030 70c1271
Update skyrl-train/examples/llm_as_a_judge/env.py
lynnliu030 6eef2eb
Update skyrl-train/skyrl_train/config/skyrl_gym_config/default.yaml
lynnliu030 cab9e41
x
SumanthRH ee87295
x
SumanthRH fd5fd84
use new dataset
SumanthRH 4d75b7c
x
SumanthRH 1c86c25
fix
lynnliu030 5e9d9f6
x
lynnliu030 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
OPENAI_API_KEY="<openai_api_key>" | ||
# optionally, enter wandb if logging with wandb | ||
# WANDB_API_KEY=<wandb_api_key> |
90 changes: 90 additions & 0 deletions
90
skyrl-train/examples/llm_as_a_judge/gsm8k_dataset_judge_env.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Preprocess the GSM8k dataset to parquet format | ||
""" | ||
|
||
import argparse | ||
import re | ||
import os | ||
|
||
import datasets | ||
|
||
|
||
def extract_solution(solution_str): | ||
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) | ||
assert solution is not None | ||
final_solution = solution.group(0) | ||
final_solution = final_solution.split("#### ")[1].replace(",", "") | ||
return final_solution | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--output_dir", default="~/data/gsm8k_llm_judge") | ||
|
||
args = parser.parse_args() | ||
|
||
args.output_dir = os.path.expanduser(args.output_dir) | ||
|
||
data_source = "openai/gsm8k" | ||
|
||
dataset = datasets.load_dataset(data_source, "main") | ||
|
||
train_dataset = dataset["train"] | ||
val_dataset = dataset["test"] | ||
|
||
instruction_following = 'Let\'s think step by step and output the final answer after "####".' | ||
|
||
# add a row to each data item that represents a unique id | ||
def make_map_fn(split): | ||
def process_fn(example, idx): | ||
question_raw = example.pop("question") | ||
|
||
question = question_raw + " " + instruction_following | ||
|
||
answer_raw = example.pop("answer") | ||
solution = extract_solution(answer_raw) | ||
data = { | ||
"data_source": data_source, | ||
"prompt": [ | ||
{ | ||
"role": "user", | ||
"content": question, | ||
} | ||
], | ||
# TODO: just repeating the full data preprocess script for a single env change isn't very convenient. | ||
"env_class": "llm_as_a_judge", | ||
"reward_spec": { | ||
"method": "rule", | ||
"ground_truth": solution, | ||
}, | ||
"extra_info": { | ||
"split": split, | ||
"index": idx, | ||
"answer": answer_raw, | ||
"question": question_raw, | ||
}, | ||
} | ||
return data | ||
|
||
return process_fn | ||
|
||
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) | ||
val_dataset = val_dataset.map(function=make_map_fn("test"), with_indices=True) | ||
|
||
output_dir = args.output_dir | ||
os.makedirs(output_dir, exist_ok=True) | ||
train_dataset.to_parquet(os.path.join(output_dir, "train.parquet")) | ||
val_dataset.to_parquet(os.path.join(output_dir, "validation.parquet")) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput | ||
from typing import Any | ||
from typing import Dict | ||
from omegaconf import DictConfig | ||
from openai import OpenAI | ||
import os | ||
import re | ||
|
||
PROMPT = """ | ||
You are a strict math evaluation assistant. | ||
|
||
Compare the following **gold** and **predicted** math solutions. | ||
Determine if the predicted solution follows valid reasoning and reaches the correct final answer, even if the explanation differs in wording. | ||
|
||
Rules: | ||
- Only answer "1" if the predicted solution is mathematically correct and leads to the same final answer as the gold solution. | ||
- Otherwise, answer "0". | ||
- Do not include any explanation or extra text—output only a single character: "1" or "0". | ||
""" | ||
|
||
|
||
class GSM8kLLMJudgeEnv(BaseTextEnv): | ||
""" | ||
Example implementtion of GSM8k environment with LLM as judge. | ||
|
||
Use LLM as judge to evaluate the answer similarity with the ground truth. | ||
""" | ||
|
||
def __init__(self, env_config: DictConfig, extras: Dict[str, Any] = {}): | ||
super().__init__() | ||
|
||
assert "reward_spec" in extras, "reward_spec field is required" | ||
assert "ground_truth" in extras["reward_spec"], "ground_truth is required in reward_spec field" | ||
self.ground_truth = extras["reward_spec"]["ground_truth"] | ||
|
||
# Set up OpenAI client | ||
openai_api_key = os.getenv("OPENAI_API_KEY") | ||
if openai_api_key is None: | ||
raise ValueError("`OPENAI_API_KEY` must be set for Llm as a judge env") | ||
self.llm_judge_client = OpenAI(base_url=env_config.base_url, api_key=openai_api_key) | ||
self.model = env_config.model | ||
|
||
def _get_reward(self, action: str) -> float: | ||
message = PROMPT + f"\n\nGOLD SOLUTION:\n{self.ground_truth}\n\nPREDICTED SOLUTION:\n{action}\n\nAnswer:" | ||
|
||
try: | ||
response = self.llm_judge_client.chat.completions.create( | ||
model=self.model, messages=[{"role": "user", "content": message}] | ||
) | ||
reply = response.choices[0].message.content.strip() | ||
|
||
# Try to parse score from "### Final Score: x" | ||
match = re.search(r"### Final Score:\s*([01](?:\.0)?)", reply) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why try to parse from "### Final Score: x" since we explicitly prompted it to only return 0 or 1? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the prompt! |
||
if match: | ||
return float(match.group(1)) | ||
|
||
# Fallback: raw "1" or "0" | ||
if reply.strip() in {"1", "0"}: | ||
return float(reply.strip()) | ||
|
||
print(f"Unrecognized reward output: {reply}") | ||
return 0.0 | ||
|
||
except Exception as e: | ||
print(f"LLM Judge error: {type(e).__name__}: {e}") | ||
return 0.0 | ||
|
||
def step(self, action: str) -> BaseTextEnvStepOutput: | ||
done = True | ||
reward = self._get_reward(action) | ||
|
||
return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={}) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
uv run --isolated --extra vllm -m examples.llm_as_a_judge.main_llm_judge | ||
""" | ||
|
||
import ray | ||
import hydra | ||
from omegaconf import DictConfig | ||
from skyrl_train.utils import initialize_ray | ||
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg | ||
from skyrl_gym.envs import register | ||
|
||
|
||
@ray.remote(num_cpus=1) | ||
def skyrl_entrypoint(cfg: DictConfig): | ||
# Register the multiply environment inside the entrypoint task (no need to modify the skyrl-gym package). | ||
lynnliu030 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
register( | ||
id="llm_as_a_judge", | ||
entry_point="examples.llm_as_a_judge.llm_judge_env:GSM8kLLMJudgeEnv", | ||
) | ||
|
||
# make sure that the training loop is not run on the head node. | ||
exp = BasePPOExp(cfg) | ||
exp.run() | ||
|
||
|
||
@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) | ||
def main(cfg: DictConfig) -> None: | ||
# validate the arguments | ||
validate_cfg(cfg) | ||
|
||
initialize_ray(cfg) | ||
ray.get(skyrl_entrypoint.remote(cfg)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
set -x | ||
|
||
# Colocated GRPO training+generation for Qwen2.5-Coder-7B-Instruct on SkyRL-SQL-653 data. | ||
lynnliu030 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Uses 1 node with 8 GPUs. | ||
lynnliu030 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# uv run examples/llm_as_a_judge/gsm8k_dataset_judge_env.py --output_dir $HOME/data/gsm8k_llm_judge | ||
# add OPENAI_API_KEY and WANDB_API_KEY to .env.llm_judge | ||
# bash examples/llm_as_a_judge/run_llm_judge.sh | ||
|
||
DATA_DIR="$HOME/data/gsm8k_llm_judge" | ||
CKPT_PATH="$HOME/ckpts/llm_judge" | ||
|
||
NUM_GPUS=4 | ||
NUM_INFERENCE_ENGINES=4 | ||
TP_SIZE=1 | ||
LOGGER=wandb | ||
|
||
# We use a smaller batch size here for demonstration | ||
uv run --isolated --extra vllm --env-file .env.llm_judge -m examples.llm_as_a_judge.main_llm_judge \ | ||
data.train_data="['$DATA_DIR/train.parquet']" \ | ||
data.val_data="['$DATA_DIR/validation.parquet']" \ | ||
trainer.algorithm.advantage_estimator="grpo" \ | ||
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ | ||
trainer.placement.colocate_all=true \ | ||
trainer.strategy=fsdp2 \ | ||
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ | ||
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ | ||
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \ | ||
generator.inference_engine_tensor_parallel_size=$TP_SIZE \ | ||
trainer.epochs=20 \ | ||
trainer.eval_batch_size=32 \ | ||
trainer.eval_before_train=false \ | ||
trainer.eval_interval=5 \ | ||
trainer.update_epochs_per_batch=1 \ | ||
trainer.train_batch_size=32 \ | ||
trainer.policy_mini_batch_size=32 \ | ||
trainer.micro_forward_batch_size_per_gpu=40 \ | ||
trainer.micro_train_batch_size_per_gpu=40 \ | ||
trainer.ckpt_interval=10 \ | ||
trainer.max_prompt_length=512 \ | ||
generator.sampling_params.max_generate_length=1024 \ | ||
trainer.policy.optimizer_config.lr=1.0e-6 \ | ||
trainer.algorithm.use_kl_loss=true \ | ||
generator.backend=vllm \ | ||
generator.run_engines_locally=true \ | ||
generator.weight_sync_backend=nccl \ | ||
generator.async_engine=true \ | ||
generator.batched=true \ | ||
generator.n_samples_per_prompt=5 \ | ||
generator.gpu_memory_utilization=0.8 \ | ||
trainer.logger="$LOGGER" \ | ||
trainer.project_name="gsm8k" \ | ||
trainer.run_name="gsm8k_llm_as_a_judge" \ | ||
trainer.resume_mode=null \ | ||
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \ | ||
environment.env_class=llm_as_a_judge \ | ||
environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini" \ | ||
lynnliu030 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
$@ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,16 @@ | ||
# @package environment.skyrl_gym | ||
# number of background workers for env step calls. Set to 0 to disable background workers. | ||
max_env_workers: 32 | ||
|
||
text2sql: | ||
db_path: "/home/ray/default/sql_data" | ||
|
||
llm_as_a_judge: | ||
model: "gpt-4o-mini" | ||
base_url: null # or a local endpoint: http://localhost:8000/v1 | ||
|
||
search: | ||
log_requests: false | ||
search_url: "http://127.0.0.1:8000/retrieve" | ||
topk: 3 | ||
timeout: 30 | ||
timeout: 30 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.