 import argparse
 import asyncio
 import concurrent.futures
-import os
-from typing import Any, Dict, List
+import random
+from typing import List
 
 from dotenv import load_dotenv
 
 import art
 from tau_bench.agents.tool_calling_agent import ToolCallingRLAgent
 from tau_bench.types import TauBenchPolicyConfig, TauBenchTrainingConfig
 from tau_bench.general_rm import create_general_rm_trajectory_groups
-from langfuse import Langfuse
+from tau_bench.rl_utils import log_trajectory_to_openpipe, update_steps_for_openpipe_logs
 from tqdm.asyncio import tqdm_asyncio
 
 # Load environment variables
 load_dotenv(override=True)
 
-def log_trajectory_to_langfuse(
-    traj: art.Trajectory,
-    messages: List[Dict[str, Any]]
-) -> None:
-    """
-    Push one trajectory to Langfuse with task_idx and step for comparison.
-    """
-    # Initialize langfuse
-    langfuse = Langfuse(
-        secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
-        public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
-        host=os.getenv("LANGFUSE_HOST"),
-    )
-    phase = traj.metadata.get("phase", "unknown")
-    step = traj.metadata.get("training_step", 0)
-    task_idx = traj.metadata.get("task_index", 0)
-    env = traj.metadata.get("env", "unknown")
-
-    trace_name = f"rl-{phase}-step-{step}-task-{task_idx}"
-
-    # Create trace with trajectory data
-    trace = langfuse.trace(
-        name=trace_name,
-        input={
-            "task_idx": task_idx,
-            "step": step,
-            "phase": phase,
-            "metadata": traj.metadata
-        },
-        output={
-            "messages": messages,
-            "reward": traj.reward,
-            "metadata": traj.metadata
-        },
-        metadata={
-            "task_idx": task_idx,
-            "training_step": step,
-            "phase": phase,
-            "env": env
-        }
-    )
-
-    # Add reward as a score
-    trace.score(name="reward", value=traj.reward)
-
 async def rollout_tau_bench_task(
     model: art.Model[TauBenchPolicyConfig],
     task_index: int,
@@ -109,10 +64,12 @@ async def rollout_tau_bench_task(
         messages_and_choices=[],
         reward=0,
         metadata={
-            "task_index": task_index,
+            "task_index": str(task_index),
             "env": config.env,
-            "training_step": step,
-            "phase": phase
+            "training_step": str(step),
+            "phase": phase,
+            "model": model.name,
+            "reward_type": config.reward_type,
         }
     )
 
@@ -126,14 +83,18 @@ async def rollout_tau_bench_task(
 
         # Convert result to trajectory format
         traj.reward = result.reward
-        traj.metadata.update(result.info)
         traj.metrics = {
             "total_steps": result.info["total_steps"],
             "final_prompt_tokens": result.info["final_prompt_tokens"],
             "avg_completion_tokens": result.info["avg_completion_tokens"],
             "max_completion_tokens": result.info["max_completion_tokens"],
+            "outcome_correct": traj.reward,
         }
-
+        traj.metadata.update(result.info)
+        traj.metadata["reward"] = "pending_general_rm" if config.reward_type == "general_rm" else traj.reward
+        traj.metadata["outcome_correct"] = traj.reward
+
+
         traj.messages_and_choices = agent.create_messages_and_choices(result.messages)
     except Exception as e:
         print(f"Error in rollout for task {task_index}: {e}")
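A minimal, self-contained sketch of the reward tagging added above: when the general RM assigns rewards later, the rollout records a placeholder rather than the environment reward. The helper name below is illustrative only, not part of the diff.

def reward_tag(reward_type: str, reward: float):
    # General-RM rewards are assigned after rollout, so log a placeholder for now.
    return "pending_general_rm" if reward_type == "general_rm" else reward

assert reward_tag("general_rm", 1.0) == "pending_general_rm"
assert reward_tag("real", 1.0) == 1.0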
@@ -142,11 +103,11 @@ async def rollout_tau_bench_task(
 
     traj.finish()
 
-    # Log to langfuse
+    # Log to langfuse/openpipe
     try:
-        log_trajectory_to_langfuse(traj, result.messages)
+        await log_trajectory_to_openpipe(traj, result.messages)
     except Exception as e:
-        print(f"Error logging trajectory to langfuse: {e}")
+        print(f"Error logging trajectory to openpipe: {e}")
 
     # print(f"Finished rolling out task {task_index} (reward: {traj.reward})")
     return traj
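The logging call stays wrapped in try/except so a failed trace upload can never fail the rollout itself. A runnable sketch of that pattern, using a stand-in logger rather than the real tau_bench helper:

import asyncio

async def safe_log(logger, traj, messages) -> None:
    # Report logging errors, never raise them into the rollout.
    try:
        await logger(traj, messages)
    except Exception as e:
        print(f"Error logging trajectory to openpipe: {e}")

async def broken_logger(traj, messages):
    raise RuntimeError("backend unreachable")

asyncio.run(safe_log(broken_logger, None, []))  # prints the error instead of crashing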
@@ -239,6 +200,8 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
     parser.add_argument("--reward-type", type=str, default="real", help="Reward type")
     parser.add_argument("--general-rm-model", type=str, default="o3", help="Model to use for general RM. ignored if reward type is not general_rm")
     parser.add_argument("--max-num-steps", type=int, default=30, help="Maximum number of steps per rollout")
+    parser.add_argument("--train-mode", type=str, default="sync_rl", choices=["sync_rl", "async_rl"], help="Training mode")
+    parser.add_argument("--skip-eval", action="store_true", default=False, help="Skip evaluation")
 
     args = parser.parse_args()
     print(args)
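A quick sketch of how the two new flags behave once parsed; the flag definitions mirror the diff, while the argv list is only an example:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train-mode", type=str, default="sync_rl", choices=["sync_rl", "async_rl"], help="Training mode")
parser.add_argument("--skip-eval", action="store_true", default=False, help="Skip evaluation")

# "--skip-eval" is a boolean switch; "--train-mode" is restricted to the two choices.
args = parser.parse_args(["--train-mode", "async_rl", "--skip-eval"])
assert args.train_mode == "async_rl" and args.skip_eval is True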
@@ -258,14 +221,15 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
         end_index=args.end_index,
         task_ids=args.task_ids,
         log_dir=args.log_dir,
-        max_concurrency=1,  # RL training is sequential
+        max_concurrency=50,
         seed=args.seed,
         shuffle=args.shuffle,
         user_strategy=args.user_strategy,
         few_shot_displays_path=args.few_shot_displays_path,
         reward_type=args.reward_type,
         general_rm_model=args.general_rm_model,
-        max_num_steps=args.max_num_steps
+        max_num_steps=args.max_num_steps,
+        skip_eval=args.skip_eval,
     )
 
     # Create training config
@@ -277,6 +241,7 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
         val_set_size=args.val_set_size,
         training_dataset_size=args.training_dataset_size,
         num_epochs=args.num_epochs,
+        train_mode=args.train_mode,
     )
 
     return run_config, training_config, args
@@ -286,27 +251,17 @@ async def evaluate_model(
     model: art.TrainableModel[TauBenchPolicyConfig],
     config: RunConfig,
     step: int,
-    num_eval_tasks: int = 50
+    val_task_indices: List[int]
 ) -> float:
     """Evaluate the model on a subset of tasks"""
-    print(f"Evaluating model on {num_eval_tasks} tasks...")
-
-    # Get environment for evaluation
-    env = get_env(
-        config.env,
-        user_strategy=config.user_strategy,
-        user_model=config.user_model,
-        user_provider=config.user_model_provider,
-        task_split=config.task_split,
-    )
+    print(f"Evaluating model on {len(val_task_indices)} tasks...")
 
     total_reward = 0.0
-    eval_tasks = min(num_eval_tasks, len(env.tasks))
 
     trajectories = await art.gather_trajectories(
         (
-            async_rollout_tau_bench_task(model, i, step, "val")
-            for i in range(eval_tasks)
+            async_rollout_tau_bench_task(model, val_task_index, step, "val")
+            for val_task_index in val_task_indices
         )
     )
     await model.log(trajectories=trajectories, split="val")
@@ -315,7 +270,7 @@ async def evaluate_model(
         total_reward += traj.reward
         print(f"Eval task {traj.metadata['task_index']}: reward={traj.reward}")
 
-    avg_reward = total_reward / eval_tasks
+    avg_reward = total_reward / len(val_task_indices)
     print(f"Average evaluation reward: {avg_reward}")
     return avg_reward
 
@@ -360,71 +315,123 @@ async def train(model: art.TrainableModel[TauBenchPolicyConfig]):
 
     print(f"Training on {len(train_task_indices)} tasks")
     print(f"Validation on {len(val_task_indices)} tasks")
-
-    # Training iterator
-    train_iterator = iterate_dataset(
-        train_task_indices,
-        groups_per_step=training_config.groups_per_step,
-        num_epochs=training_config.num_epochs,
-        initial_step=await model.get_step(),
-    )
-
-    for batch, epoch, global_step, epoch_step in train_iterator:
-        print(f"\n--- Training Step {global_step} (Epoch {epoch}, Step {epoch_step}) ---")
-
-        # Evaluation
-        if global_step % training_config.eval_steps == 0:
-            print(f"\n--- Evaluating at Step {global_step} ---")
-            await evaluate_model(model, config, global_step, num_eval_tasks=len(val_task_indices))
-            await model.delete_checkpoints()
-
-        # Generate trajectory groups
-        print(f"Generating trajectories for {len(batch)} tasks...")
-        groups = await art.gather_trajectory_groups(
+
+    if training_config.train_mode == "async_rl":
+        global_step = 0
+        train_task_indices_async_rl = []
+        for _ in range(training_config.num_epochs):
+            train_task_indices_async_rl.extend(random.sample(train_task_indices, len(train_task_indices)))
+
+        async for trajectory_groups in art.trajectory_group_batches(
             (
                 art.TrajectoryGroup(
                     (
-                        async_rollout_tau_bench_task(model, task_index, global_step, "train")
+                        async_rollout_tau_bench_task(model, task_index, -1, "train")
                         for _ in range(training_config.trajectories_per_group)
                     )
                 )
-                for task_index in batch
-            )
-        )
-        if config.reward_type == "general_rm":
-            print("Creating general RM trajectory groups...")
-            updated_groups = await tqdm_asyncio.gather(
-                *[
-                    create_general_rm_trajectory_groups(group, config)
-                    for group in groups
-                ],
-                desc="Creating general RM trajectory groups",
-                total=len(groups),
-            )
-            groups = updated_groups
-
-        # Training step
-        print(f"Training on {len(groups)} trajectory groups...")
-        await model.train(
-            groups,
-            config=art.TrainConfig(
-                learning_rate=training_config.learning_rate
+                for task_index in train_task_indices_async_rl
             ),
+            batch_size=training_config.groups_per_step,
+            max_concurrent_batches=3,
+            skip_batches=await model.get_step(),
+        ):
+            if global_step % training_config.eval_steps == 0 and not config.skip_eval:
+                print(f"\n--- Evaluating at Step {global_step} ---")
+                await evaluate_model(model, config, global_step, val_task_indices)
+                # await model.delete_checkpoints()
+
+            if config.reward_type == "general_rm":
+                print("Creating general RM trajectory groups...")
+                updated_groups = await tqdm_asyncio.gather(
+                    *[
+                        create_general_rm_trajectory_groups(group, config)
+                        for group in trajectory_groups
+                    ],
+                    desc="Creating general RM trajectory groups",
+                    total=len(trajectory_groups),
+                )
+                trajectory_groups = updated_groups
+
+            try:
+                await update_steps_for_openpipe_logs(trajectory_groups, global_step)
+            except Exception as e:
+                print(f"Error updating steps for openpipe logs: {e}")
+
+            # Training step
+            print(f"Training on {len(trajectory_groups)} trajectory groups...")
+            await model.train(
+                trajectory_groups,
+                config=art.TrainConfig(
+                    learning_rate=training_config.learning_rate
+                ),
+            )
+            global_step += 1
+    else:
+        # Training iterator
+        train_iterator = iterate_dataset(
+            train_task_indices,
+            groups_per_step=training_config.groups_per_step,
+            num_epochs=training_config.num_epochs,
+            initial_step=await model.get_step(),
         )
 
-        # Log progress
-        total_reward = sum(
-            sum(traj.reward for traj in group.trajectories)
-            for group in groups
-        )
-        num_trajectories = sum(len(group.trajectories) for group in groups)
-        avg_reward = total_reward / num_trajectories if num_trajectories > 0 else 0
-        print(f"Step {global_step}: Average training reward = {avg_reward}")
+        for batch, epoch, global_step, epoch_step in train_iterator:
+            print(f"\n--- Training Step {global_step} (Epoch {epoch}, Step {epoch_step}) ---")
+
+            # Evaluation
+            if global_step % training_config.eval_steps == 0 and not config.skip_eval:
+                print(f"\n--- Evaluating at Step {global_step} ---")
+                await evaluate_model(model, config, global_step, val_task_indices)
+                await model.delete_checkpoints()
+
+            # Generate trajectory groups
+            print(f"Generating trajectories for {len(batch)} tasks...")
+            groups = await art.gather_trajectory_groups(
+                (
+                    art.TrajectoryGroup(
+                        (
+                            async_rollout_tau_bench_task(model, task_index, global_step, "train")
+                            for _ in range(training_config.trajectories_per_group)
+                        )
+                    )
+                    for task_index in batch
+                )
+            )
+            if config.reward_type == "general_rm":
+                print("Creating general RM trajectory groups...")
+                updated_groups = await tqdm_asyncio.gather(
+                    *[
+                        create_general_rm_trajectory_groups(group, config)
+                        for group in groups
+                    ],
+                    desc="Creating general RM trajectory groups",
+                    total=len(groups),
+                )
+                groups = updated_groups
+
+            # Training step
+            print(f"Training on {len(groups)} trajectory groups...")
+            await model.train(
+                groups,
+                config=art.TrainConfig(
+                    learning_rate=training_config.learning_rate
+                ),
+            )
+
+            # Log progress
+            total_reward = sum(
+                sum(traj.reward for traj in group.trajectories)
+                for group in groups
+            )
+            num_trajectories = sum(len(group.trajectories) for group in groups)
+            avg_reward = total_reward / num_trajectories if num_trajectories > 0 else 0
+            print(f"Step {global_step}: Average training reward = {avg_reward}")
 
     # Final evaluation
     print("\n--- Final Evaluation ---")
     final_step = await model.get_step()
-    final_reward = await evaluate_model(model, config, final_step, num_eval_tasks=len(val_task_indices))
+    final_reward = await evaluate_model(model, config, final_step, val_task_indices)
     print(f"Final average reward: {final_reward}")
 
     print("Training completed!")
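To summarize the async_rl data flow added above, here is an illustrative-only sketch using plain lists instead of ART trajectory groups: each epoch is a fresh shuffle of the training tasks, the stream is consumed in fixed-size batches, and already-trained batches are skipped so a restarted run resumes from the model's saved step.

import random

def async_rl_task_stream(train_task_indices, num_epochs, batch_size, skip_batches=0):
    ordered = []
    for _ in range(num_epochs):
        # random.sample over the full length returns a shuffled copy per epoch.
        ordered.extend(random.sample(train_task_indices, len(train_task_indices)))
    batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]
    return batches[skip_batches:]

remaining = async_rl_task_stream(list(range(8)), num_epochs=2, batch_size=4, skip_batches=1)
assert len(remaining) == 3 and all(len(b) == 4 for b in remaining)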