Commit b1bd80c

Merge pull request #179 from OpenPipe/tau_bench_async_rl
Tau bench async rl
2 parents f0f4c2c + ea5212d commit b1bd80c

9 files changed: +406 −175 lines changed


dev/tau-bench/pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -3,16 +3,17 @@ name = "tau-bench"
 version = "0.1.0"
 requires-python = ">=3.11"
 dependencies = [
-    "anthropic>=0.54.0",
     "google-generativeai>=0.8.5",
     "langfuse>=2.60.8",
     "litellm>=1.72.6.post2",
     "mistralai>=1.8.2",
-    "openai>=1.88.0",
+    "openpipe>=4.50.0",
     "openpipe-art",
     "skypilot[runpod]>=0.9.3",
     "tenacity>=9.1.2",
     "termcolor>=3.1.0",
+    "openai>=1.74.0",
+    "anthropic>=0.49.0",
 ]
 
 [tool.uv.sources]
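
Note (illustrative, not part of this commit): the dependency changes above add the openpipe client used by the new logging path in run_rl.py and adjust the openai/anthropic version floors. A quick sanity check that the re-pinned packages import after syncing the environment might look like the following; everything beyond the pins themselves is an assumption.

# Illustrative check only: confirm the newly pinned packages import.
import anthropic
import openai
import openpipe  # new dependency used by the OpenPipe logging path in run_rl.py

print("openai", openai.__version__)        # expect >= 1.74.0
print("anthropic", anthropic.__version__)  # expect >= 0.49.0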

dev/tau-bench/run_rl.py

Lines changed: 132 additions & 125 deletions
@@ -3,8 +3,8 @@
 import argparse
 import asyncio
 import concurrent.futures
-import os
-from typing import Any, Dict, List
+import random
+from typing import List
 from dotenv import load_dotenv
 
 import art
@@ -19,57 +19,12 @@
 from tau_bench.agents.tool_calling_agent import ToolCallingRLAgent
 from tau_bench.types import TauBenchPolicyConfig, TauBenchTrainingConfig
 from tau_bench.general_rm import create_general_rm_trajectory_groups
-from langfuse import Langfuse
+from tau_bench.rl_utils import log_trajectory_to_openpipe, update_steps_for_openpipe_logs
 from tqdm.asyncio import tqdm_asyncio
 
 # Load environment variables
 load_dotenv(override=True)
 
-def log_trajectory_to_langfuse(
-    traj: art.Trajectory,
-    messages: List[Dict[str, Any]]
-) -> None:
-    """
-    Push one trajectory to Langfuse with task_idx and step for comparison.
-    """
-    # Initialize langfuse
-    langfuse = Langfuse(
-        secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
-        public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
-        host=os.getenv("LANGFUSE_HOST"),
-    )
-    phase = traj.metadata.get("phase", "unknown")
-    step = traj.metadata.get("training_step", 0)
-    task_idx = traj.metadata.get("task_index", 0)
-    env = traj.metadata.get("env", "unknown")
-
-    trace_name = f"rl-{phase}-step-{step}-task-{task_idx}"
-
-    # Create trace with trajectory data
-    trace = langfuse.trace(
-        name=trace_name,
-        input={
-            "task_idx": task_idx,
-            "step": step,
-            "phase": phase,
-            "metadata": traj.metadata
-        },
-        output={
-            "messages": messages,
-            "reward": traj.reward,
-            "metadata": traj.metadata
-        },
-        metadata={
-            "task_idx": task_idx,
-            "training_step": step,
-            "phase": phase,
-            "env": env
-        }
-    )
-
-    # Add reward as a score
-    trace.score(name="reward", value=traj.reward)
-
 async def rollout_tau_bench_task(
     model: art.Model[TauBenchPolicyConfig],
     task_index: int,
@@ -109,10 +64,12 @@ async def rollout_tau_bench_task(
         messages_and_choices=[],
         reward=0,
         metadata={
-            "task_index": task_index,
+            "task_index": str(task_index),
             "env": config.env,
-            "training_step": step,
-            "phase": phase
+            "training_step": str(step),
+            "phase": phase,
+            "model": model.name,
+            "reward_type": config.reward_type,
         }
     )
 
@@ -126,14 +83,18 @@ async def rollout_tau_bench_task(
 
         # Convert result to trajectory format
         traj.reward = result.reward
-        traj.metadata.update(result.info)
         traj.metrics = {
             "total_steps": result.info["total_steps"],
             "final_prompt_tokens": result.info["final_prompt_tokens"],
             "avg_completion_tokens": result.info["avg_completion_tokens"],
             "max_completion_tokens": result.info["max_completion_tokens"],
+            "outcome_correct": traj.reward,
         }
-
+        traj.metadata.update(result.info)
+        traj.metadata["reward"] = "pending_general_rm" if config.reward_type == "general_rm" else traj.reward
+        traj.metadata["outcome_correct"] = traj.reward
+
+
         traj.messages_and_choices = agent.create_messages_and_choices(result.messages)
     except Exception as e:
         print(f"Error in rollout for task {task_index}: {e}")
@@ -142,11 +103,11 @@ async def rollout_tau_bench_task(
 
     traj.finish()
 
-    # Log to langfuse
+    # Log to langfuse/openpipe
     try:
-        log_trajectory_to_langfuse(traj, result.messages)
+        await log_trajectory_to_openpipe(traj, result.messages)
     except Exception as e:
-        print(f"Error logging trajectory to langfuse: {e}")
+        print(f"Error logging trajectory to openpipe: {e}")
 
     # print(f"Finished rolling out task {task_index} (reward: {traj.reward})")
     return traj
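
Note (assumed interface, not shown in this diff): the new tau_bench/rl_utils helpers are only visible here through their call sites. Inferred from those calls, their shapes would be roughly the stubs below; the docstrings describe intent, not confirmed implementation.

# tau_bench/rl_utils.py -- assumed signatures inferred from the call sites above.
from typing import Any, Dict, List

import art


async def log_trajectory_to_openpipe(
    traj: art.Trajectory, messages: List[Dict[str, Any]]
) -> None:
    """Report one finished rollout (messages, reward, metadata) to OpenPipe."""
    ...


async def update_steps_for_openpipe_logs(
    trajectory_groups: List[art.TrajectoryGroup], global_step: int
) -> None:
    """Back-fill the real training step on logs recorded with the placeholder step (-1)."""
    ...
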
@@ -239,6 +200,8 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
     parser.add_argument("--reward-type", type=str, default="real", help="Reward type")
     parser.add_argument("--general-rm-model", type=str, default="o3", help="Model to use for general RM. ignored if reward type is not general_rm")
     parser.add_argument("--max-num-steps", type=int, default=30, help="Maximum number of steps per rollout")
+    parser.add_argument("--train-mode", type=str, default="sync_rl", choices=["sync_rl", "async_rl"], help="Training mode")
+    parser.add_argument("--skip-eval", action="store_true", default=False, help="Skip evaluation")
 
     args = parser.parse_args()
     print(args)
@@ -258,14 +221,15 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
         end_index=args.end_index,
         task_ids=args.task_ids,
         log_dir=args.log_dir,
-        max_concurrency=1,  # RL training is sequential
+        max_concurrency=50,
         seed=args.seed,
         shuffle=args.shuffle,
         user_strategy=args.user_strategy,
         few_shot_displays_path=args.few_shot_displays_path,
         reward_type=args.reward_type,
         general_rm_model=args.general_rm_model,
-        max_num_steps=args.max_num_steps
+        max_num_steps=args.max_num_steps,
+        skip_eval=args.skip_eval,
     )
 
     # Create training config
@@ -277,6 +241,7 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
         val_set_size=args.val_set_size,
         training_dataset_size=args.training_dataset_size,
         num_epochs=args.num_epochs,
+        train_mode=args.train_mode,
     )
 
     return run_config, training_config, args
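
Note (assumption): the new train_mode and skip_eval keyword arguments wired up above imply matching fields on the config classes, which are changed elsewhere in this PR (e.g. tau_bench/types.py) but not shown in this excerpt. A minimal sketch, assuming pydantic models and assumed defaults:

# Sketch only: the real definitions live in tau_bench/types.py.
from pydantic import BaseModel


class TauBenchTrainingConfig(BaseModel):
    # ...existing fields (learning_rate, groups_per_step, eval_steps, ...)
    train_mode: str = "sync_rl"  # "sync_rl" | "async_rl"


class RunConfig(BaseModel):
    # ...existing fields (env, task_split, user_strategy, reward_type, ...)
    skip_eval: bool = False  # bypass the periodic evaluate_model() calls
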
@@ -286,27 +251,17 @@ async def evaluate_model(
     model: art.TrainableModel[TauBenchPolicyConfig],
     config: RunConfig,
     step: int,
-    num_eval_tasks: int = 50
+    val_task_indices: List[int]
 ) -> float:
     """Evaluate the model on a subset of tasks"""
-    print(f"Evaluating model on {num_eval_tasks} tasks...")
-
-    # Get environment for evaluation
-    env = get_env(
-        config.env,
-        user_strategy=config.user_strategy,
-        user_model=config.user_model,
-        user_provider=config.user_model_provider,
-        task_split=config.task_split,
-    )
+    print(f"Evaluating model on {len(val_task_indices)} tasks...")
 
     total_reward = 0.0
-    eval_tasks = min(num_eval_tasks, len(env.tasks))
 
     trajectories = await art.gather_trajectories(
         (
-            async_rollout_tau_bench_task(model, i, step, "val")
-            for i in range(eval_tasks)
+            async_rollout_tau_bench_task(model, val_task_index, step, "val")
+            for val_task_index in val_task_indices
         )
     )
     await model.log(trajectories=trajectories, split="val")
@@ -315,7 +270,7 @@ async def evaluate_model(
         total_reward += traj.reward
         print(f"Eval task {traj.metadata['task_index']}: reward={traj.reward}")
 
-    avg_reward = total_reward / eval_tasks
+    avg_reward = total_reward / len(val_task_indices)
     print(f"Average evaluation reward: {avg_reward}")
     return avg_reward
 
@@ -360,71 +315,123 @@ async def train(model: art.TrainableModel[TauBenchPolicyConfig]):
 
     print(f"Training on {len(train_task_indices)} tasks")
     print(f"Validation on {len(val_task_indices)} tasks")
-
-    # Training iterator
-    train_iterator = iterate_dataset(
-        train_task_indices,
-        groups_per_step=training_config.groups_per_step,
-        num_epochs=training_config.num_epochs,
-        initial_step=await model.get_step(),
-    )
-
-    for batch, epoch, global_step, epoch_step in train_iterator:
-        print(f"\n--- Training Step {global_step} (Epoch {epoch}, Step {epoch_step}) ---")
-
-        # Evaluation
-        if global_step % training_config.eval_steps == 0:
-            print(f"\n--- Evaluating at Step {global_step} ---")
-            await evaluate_model(model, config, global_step, num_eval_tasks=len(val_task_indices))
-            await model.delete_checkpoints()
-
-        # Generate trajectory groups
-        print(f"Generating trajectories for {len(batch)} tasks...")
-        groups = await art.gather_trajectory_groups(
+
+    if training_config.train_mode == "async_rl":
+        global_step = 0
+        train_task_indices_async_rl = []
+        for _ in range(training_config.num_epochs):
+            train_task_indices_async_rl.extend(random.sample(train_task_indices, len(train_task_indices)))
+
+        async for trajectory_groups in art.trajectory_group_batches(
             (
                 art.TrajectoryGroup(
                     (
-                        async_rollout_tau_bench_task(model, task_index, global_step, "train")
+                        async_rollout_tau_bench_task(model, task_index, -1, "train")
                         for _ in range(training_config.trajectories_per_group)
                     )
                 )
-                for task_index in batch
-            )
-        )
-        if config.reward_type == "general_rm":
-            print("Creating general RM trajectory groups...")
-            updated_groups = await tqdm_asyncio.gather(
-                *[
-                    create_general_rm_trajectory_groups(group, config)
-                    for group in groups
-                ],
-                desc="Creating general RM trajectory groups",
-                total=len(groups),
-            )
-            groups = updated_groups
-
-        # Training step
-        print(f"Training on {len(groups)} trajectory groups...")
-        await model.train(
-            groups,
-            config=art.TrainConfig(
-                learning_rate=training_config.learning_rate
+                for task_index in train_task_indices_async_rl
             ),
+            batch_size=training_config.groups_per_step,
+            max_concurrent_batches=3,
+            skip_batches=await model.get_step(),
+        ):
+            if global_step % training_config.eval_steps == 0 and not config.skip_eval:
+                print(f"\n--- Evaluating at Step {global_step} ---")
+                await evaluate_model(model, config, global_step, val_task_indices)
+                # await model.delete_checkpoints()
+
+            if config.reward_type == "general_rm":
+                print("Creating general RM trajectory groups...")
+                updated_groups = await tqdm_asyncio.gather(
+                    *[
+                        create_general_rm_trajectory_groups(group, config)
+                        for group in trajectory_groups
+                    ],
+                    desc="Creating general RM trajectory groups",
+                    total=len(trajectory_groups),
+                )
+                trajectory_groups = updated_groups
+
+            try:
+                await update_steps_for_openpipe_logs(trajectory_groups, global_step)
+            except Exception as e:
+                print(f"Error updating steps for openpipe logs: {e}")
+
+            # Training step
+            print(f"Training on {len(trajectory_groups)} trajectory groups...")
+            await model.train(
+                trajectory_groups,
+                config=art.TrainConfig(
+                    learning_rate=training_config.learning_rate
+                ),
+            )
+            global_step += 1
+    else:
+        # Training iterator
+        train_iterator = iterate_dataset(
+            train_task_indices,
+            groups_per_step=training_config.groups_per_step,
+            num_epochs=training_config.num_epochs,
+            initial_step=await model.get_step(),
         )
 
-        # Log progress
-        total_reward = sum(
-            sum(traj.reward for traj in group.trajectories)
-            for group in groups
-        )
-        num_trajectories = sum(len(group.trajectories) for group in groups)
-        avg_reward = total_reward / num_trajectories if num_trajectories > 0 else 0
-        print(f"Step {global_step}: Average training reward = {avg_reward}")
+        for batch, epoch, global_step, epoch_step in train_iterator:
+            print(f"\n--- Training Step {global_step} (Epoch {epoch}, Step {epoch_step}) ---")
+
+            # Evaluation
+            if global_step % training_config.eval_steps == 0 and not config.skip_eval:
+                print(f"\n--- Evaluating at Step {global_step} ---")
+                await evaluate_model(model, config, global_step, val_task_indices)
+                await model.delete_checkpoints()
+
+            # Generate trajectory groups
+            print(f"Generating trajectories for {len(batch)} tasks...")
+            groups = await art.gather_trajectory_groups(
+                (
+                    art.TrajectoryGroup(
+                        (
+                            async_rollout_tau_bench_task(model, task_index, global_step, "train")
+                            for _ in range(training_config.trajectories_per_group)
+                        )
+                    )
+                    for task_index in batch
+                )
+            )
+            if config.reward_type == "general_rm":
+                print("Creating general RM trajectory groups...")
+                updated_groups = await tqdm_asyncio.gather(
+                    *[
+                        create_general_rm_trajectory_groups(group, config)
+                        for group in groups
+                    ],
+                    desc="Creating general RM trajectory groups",
+                    total=len(groups),
+                )
+                groups = updated_groups
+
+            # Training step
+            print(f"Training on {len(groups)} trajectory groups...")
+            await model.train(
+                groups,
+                config=art.TrainConfig(
+                    learning_rate=training_config.learning_rate
+                ),
+            )
+
+            # Log progress
+            total_reward = sum(
+                sum(traj.reward for traj in group.trajectories)
+                for group in groups
+            )
+            num_trajectories = sum(len(group.trajectories) for group in groups)
+            avg_reward = total_reward / num_trajectories if num_trajectories > 0 else 0
+            print(f"Step {global_step}: Average training reward = {avg_reward}")
 
     # Final evaluation
     print("\n--- Final Evaluation ---")
     final_step = await model.get_step()
-    final_reward = await evaluate_model(model, config, final_step, num_eval_tasks=len(val_task_indices))
+    final_reward = await evaluate_model(model, config, final_step, val_task_indices)
     print(f"Final average reward: {final_reward}")
 
     print("Training completed!")
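
Note (conceptual sketch, not ART's implementation): the async_rl branch above relies on art.trajectory_group_batches to keep several batches of rollouts in flight while earlier batches are being trained on, instead of strictly alternating rollout and training as in the sync_rl branch. The scheduling idea, reduced to plain asyncio over hypothetical job awaitables, is roughly:

import asyncio
from typing import AsyncIterator, Awaitable, Iterable, List, TypeVar

T = TypeVar("T")


async def batches_with_lookahead(
    jobs: Iterable[Awaitable[T]],
    batch_size: int,
    max_concurrent_batches: int,
) -> AsyncIterator[List[T]]:
    """Yield batches of completed jobs while later batches are already running."""
    job_list = list(jobs)
    batches = [job_list[i : i + batch_size] for i in range(0, len(job_list), batch_size)]

    async def run_batch(batch: List[Awaitable[T]]) -> List[T]:
        return await asyncio.gather(*batch)  # roll out one batch concurrently

    # Keep up to max_concurrent_batches batch tasks in flight at once.
    in_flight = [asyncio.create_task(run_batch(b)) for b in batches[:max_concurrent_batches]]
    next_batch = max_concurrent_batches
    while in_flight:
        results = await in_flight.pop(0)  # oldest batch first, so training order is stable
        if next_batch < len(batches):
            in_flight.append(asyncio.create_task(run_batch(batches[next_batch])))
            next_batch += 1
        yield results  # the training step consumes this batch while the rest keep rolling

Because rollouts launched this way do not yet know which optimizer step they will land in, the rollout above logs a placeholder training_step of -1 and update_steps_for_openpipe_logs rewrites it once the batch is assigned to a step.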
