3 changes: 3 additions & 0 deletions skyrl-train/.env.llm_judge
@@ -0,0 +1,3 @@
OPENAI_API_KEY="<openai_api_key>"
# optionally, set WANDB_API_KEY if logging with wandb
# WANDB_API_KEY=<wandb_api_key>
90 changes: 90 additions & 0 deletions skyrl-train/examples/llm_as_a_judge/gsm8k_dataset_judge_env.py
@@ -0,0 +1,90 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset to parquet format
"""

import argparse
import re
import os

import datasets


def extract_solution(solution_str):
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default="~/data/gsm8k_llm_judge")

args = parser.parse_args()

args.output_dir = os.path.expanduser(args.output_dir)

data_source = "openai/gsm8k"

dataset = datasets.load_dataset(data_source, "main")

train_dataset = dataset["train"]
val_dataset = dataset["test"]

instruction_following = 'Let\'s think step by step and output the final answer after "####".'

# map each data item into the SkyRL format, adding a unique index per split
def make_map_fn(split):
def process_fn(example, idx):
question_raw = example.pop("question")

question = question_raw + " " + instruction_following

answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"data_source": data_source,
"prompt": [
{
"role": "user",
"content": question,
}
],
# TODO: just repeating the full data preprocess script for a single env change isn't very convenient.
"env_class": "llm_as_a_judge",
"reward_spec": {
"method": "rule",
"ground_truth": solution,
},
"extra_info": {
"split": split,
"index": idx,
"answer": answer_raw,
"question": question_raw,
},
}
return data

return process_fn

train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
val_dataset = val_dataset.map(function=make_map_fn("test"), with_indices=True)

output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
train_dataset.to_parquet(os.path.join(output_dir, "train.parquet"))
val_dataset.to_parquet(os.path.join(output_dir, "validation.parquet"))
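For a quick sanity check of the generated parquet files, the snippet below (illustrative, not part of the diff) loads the train split back and inspects one row; it assumes the script above has already been run with the default `--output_dir`.

```python
# Sanity-check the preprocessed dataset (illustrative; not part of the committed example).
import os
from datasets import load_dataset

data_dir = os.path.expanduser("~/data/gsm8k_llm_judge")
ds = load_dataset("parquet", data_files=os.path.join(data_dir, "train.parquet"))["train"]

row = ds[0]
print(row["env_class"])              # "llm_as_a_judge"
print(row["reward_spec"]["method"])  # "rule"
print(row["prompt"][0]["role"])      # "user"
```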
72 changes: 72 additions & 0 deletions skyrl-train/examples/llm_as_a_judge/llm_judge_env.py
@@ -0,0 +1,72 @@
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput
from typing import Any
from typing import Dict
from omegaconf import DictConfig
from openai import OpenAI
import os
import re

PROMPT = """
You are a strict math evaluation assistant.

Compare the following **gold** and **predicted** math solutions.
Determine if the predicted solution follows valid reasoning and reaches the correct final answer, even if the explanation differs in wording.

Rules:
- Only answer "1" if the predicted solution is mathematically correct and leads to the same final answer as the gold solution.
- Otherwise, answer "0".
- Do not include any explanation or extra text—output only a single character: "1" or "0".
"""


class GSM8kLLMJudgeEnv(BaseTextEnv):
"""
Example implementation of the GSM8k environment with LLM-as-a-judge.

Uses an LLM judge to evaluate whether the answer matches the ground truth.
"""

def __init__(self, env_config: DictConfig, extras: Dict[str, Any] = {}):
super().__init__()

assert "reward_spec" in extras, "reward_spec field is required"
assert "ground_truth" in extras["reward_spec"], "ground_truth is required in reward_spec field"
self.ground_truth = extras["reward_spec"]["ground_truth"]

# Set up OpenAI client
openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key is None:
raise ValueError("`OPENAI_API_KEY` must be set for the LLM-as-a-judge env")
self.llm_judge_client = OpenAI(base_url=env_config.base_url, api_key=openai_api_key)
self.model = env_config.model

def _get_reward(self, action: str) -> float:
message = PROMPT + f"\n\nGOLD SOLUTION:\n{self.ground_truth}\n\nPREDICTED SOLUTION:\n{action}\n\nAnswer:"

try:
response = self.llm_judge_client.chat.completions.create(
model=self.model, messages=[{"role": "user", "content": message}]
)
reply = response.choices[0].message.content.strip()

# Try to parse score from "### Final Score: x"
match = re.search(r"### Final Score:\s*([01](?:\.0)?)", reply)
Member: Why try to parse from "### Final Score: x" since we explicitly prompted it to only return 0 or 1?

Member (Author): Update the prompt!

if match:
return float(match.group(1))

# Fallback: raw "1" or "0"
if reply.strip() in {"1", "0"}:
return float(reply.strip())

print(f"Unrecognized reward output: {reply}")
return 0.0

except Exception as e:
print(f"LLM Judge error: {type(e).__name__}: {e}")
return 0.0

def step(self, action: str) -> BaseTextEnvStepOutput:
done = True
reward = self._get_reward(action)

return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={})
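Following the review thread above, here is a minimal sketch (not the committed code) of how `_get_reward` could look once the prompt is updated so the judge only ever replies with a bare "1" or "0", making the `### Final Score:` fallback unnecessary.

```python
# Illustrative simplification discussed in the review; assumes the prompt is
# updated so the judge replies with a single character, "1" or "0".
def _get_reward(self, action: str) -> float:
    message = PROMPT + f"\n\nGOLD SOLUTION:\n{self.ground_truth}\n\nPREDICTED SOLUTION:\n{action}\n\nAnswer:"
    try:
        response = self.llm_judge_client.chat.completions.create(
            model=self.model, messages=[{"role": "user", "content": message}]
        )
        reply = response.choices[0].message.content.strip()
        if reply in {"1", "0"}:
            return float(reply)
        print(f"Unrecognized reward output: {reply}")
        return 0.0
    except Exception as e:
        print(f"LLM Judge error: {type(e).__name__}: {e}")
        return 0.0
```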
36 changes: 36 additions & 0 deletions skyrl-train/examples/llm_as_a_judge/main_llm_judge.py
@@ -0,0 +1,36 @@
"""
uv run --isolated --extra vllm -m examples.llm_as_a_judge.main_llm_judge
"""

import ray
import hydra
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_gym.envs import register


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
# Register the llm_as_a_judge environment inside the entrypoint task (no need to modify the skyrl-gym package).
register(
id="llm_as_a_judge",
entry_point="examples.llm_as_a_judge.llm_judge_env:GSM8kLLMJudgeEnv",
)

# make sure that the training loop is not run on the head node.
exp = BasePPOExp(cfg)
exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
# validate the arguments
validate_cfg(cfg)

initialize_ray(cfg)
ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
main()
57 changes: 57 additions & 0 deletions skyrl-train/examples/llm_as_a_judge/run_llm_judge.sh
@@ -0,0 +1,57 @@
set -x

# Colocated GRPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K with an LLM-as-a-judge reward.
# Uses 1 node with 4 GPUs.
# uv run examples/llm_as_a_judge/gsm8k_dataset_judge_env.py --output_dir $HOME/data/gsm8k_llm_judge
# add OPENAI_API_KEY and WANDB_API_KEY to .env.llm_judge
# bash examples/llm_as_a_judge/run_llm_judge.sh

DATA_DIR="$HOME/data/gsm8k_llm_judge"
CKPT_PATH="$HOME/ckpts/llm_judge"

NUM_GPUS=4
NUM_INFERENCE_ENGINES=4
TP_SIZE=1
LOGGER=wandb

# We use a smaller batch size here for demonstration
uv run --isolated --extra vllm --env-file .env.llm_judge -m examples.llm_as_a_judge.main_llm_judge \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine_tensor_parallel_size=$TP_SIZE \
trainer.epochs=20 \
trainer.eval_batch_size=32 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=32 \
trainer.policy_mini_batch_size=32 \
trainer.micro_forward_batch_size_per_gpu=40 \
trainer.micro_train_batch_size_per_gpu=40 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k" \
trainer.run_name="gsm8k_llm_as_a_judge" \
trainer.resume_mode=null \
trainer.ckpt_path="$CKPT_PATH" \
environment.env_class=llm_as_a_judge \
environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini" \
"$@"
8 changes: 7 additions & 1 deletion skyrl-train/skyrl_train/config/skyrl_gym_config/default.yaml
@@ -1,10 +1,16 @@
# @package environment.skyrl_gym
# number of background workers for env step calls. Set to 0 to disable background workers.
max_env_workers: 32

text2sql:
db_path: "/home/ray/default/sql_data"

llm_as_a_judge:
model: "gpt-4o-mini"
base_url: null # or a local endpoint: http://localhost:8000/v1

search:
log_requests: false
search_url: "http://127.0.0.1:8000/retrieve"
topk: 3
timeout: 30