
allow training without logprobs experimentation #186


Merged · 1 commit · Jul 1, 2025
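For orientation, the core of this change is a "messages-only" mode: rollouts store plain chat-message dicts (with `None`-valued fields stripped) instead of logprob-bearing `Choices`, and training sets the dev flag `allow_training_without_logprobs`. The sketch below mirrors that wiring using the helpers and config fields visible in the diff; the exact `model.train` call signature and the surrounding objects (`result`, `agent`, `traj`, `trajectory_groups`) are assumptions inferred from the diff context, not a verbatim excerpt.

```python
# Sketch only: mirrors the messages-only wiring added in this PR.
# `result`, `agent`, `traj`, `trajectory_groups`, and the config objects are
# assumed to come from the existing tau-bench rollout/training loop.
from typing import Any, Dict, List

import art


def clean_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Drop None-valued keys so the trainer receives plain chat messages.
    return [{k: v for k, v in m.items() if v is not None} for m in messages]


async def store_and_train(model, config, training_config, agent, result, traj, trajectory_groups):
    if config.messages_only:
        # No logprobs available: keep the raw messages only.
        traj.messages_and_choices = clean_messages(result.messages)
    else:
        # Default path: Choices carry logprobs from the policy's own sampling.
        traj.messages_and_choices = agent.create_messages_and_choices(result.messages)

    await model.train(
        trajectory_groups,
        config=art.TrainConfig(learning_rate=training_config.learning_rate),
        # Dev-level flag that lets training proceed when trajectories lack logprobs.
        _config=art.dev.TrainConfig(
            allow_training_without_logprobs=config.messages_only
        ),
    )
```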
43 changes: 36 additions & 7 deletions dev/tau-bench/run_rl.py
@@ -3,15 +3,16 @@
import argparse
import asyncio
import concurrent.futures
import copy
import random
from typing import List
from typing import Any, Dict, List
from dotenv import load_dotenv

import art
from art.local import LocalBackend
from art.utils import iterate_dataset

from tau_bench.types import RunConfig
from tau_bench.types import RunConfig, SolveResult
from tau_bench.envs import get_env
from tau_bench.run import agent_factory
from litellm import provider_list
@@ -25,19 +26,32 @@
# Load environment variables
load_dotenv(override=True)

def clean_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    cleaned_messages = []
    for msg in messages:
        cleaned_msg = {k: v for k, v in msg.items() if v is not None}
        cleaned_messages.append(cleaned_msg)
    return cleaned_messages

async def rollout_tau_bench_task(
    model: art.Model[TauBenchPolicyConfig],
    task_index: int,
    step: int = 0,
    phase: str = "train",
    is_shadow: bool = False,
) -> art.Trajectory:
    """
    Generate a trajectory for a single tau-bench task using the given model.
    This adapts the tau-bench evaluation loop for RL trajectory generation.
    Now truly async.
    """
    # print(f"Rolling out task {task_index} (step {step}, phase {phase})")
    config = model.config.run_config
    config = copy.deepcopy(model.config.run_config)
    if is_shadow:
        config.model = "gpt-4.1"
        config.model_provider = "openai"
        config.api_key = None
        config.base_url = None

    # Get isolated environment for this task
    env = get_env(
@@ -71,6 +85,7 @@ async def rollout_tau_bench_task(
            "phase": phase,
            "model": model.name,
            "reward_type": config.reward_type,
            "is_shadow": str(is_shadow),
        }
    )

@@ -98,11 +113,20 @@ async def rollout_tau_bench_task(
        traj.metadata["outcome_correct"] = traj.metrics["outcome_correct"]


        traj.messages_and_choices = agent.create_messages_and_choices(result.messages)
        if config.messages_only:
            traj.messages_and_choices = clean_messages(result.messages) # type: ignore
        else:
            traj.messages_and_choices = agent.create_messages_and_choices(result.messages) # type: ignore
    except Exception as e:
        print(f"Error in rollout for task {task_index}: {e}")
        traj.reward = 0.0
        traj.metadata["error"] = str(e)
        result = SolveResult(
            reward=0.0,
            info={"error": str(e)},
            messages=[],
            total_cost=0.0,
        )

    traj.finish()

@@ -122,11 +146,12 @@ async def async_rollout_tau_bench_task(
    task_index: int,
    step: int = 0,
    phase: str = "train",
    is_shadow: bool = False,
) -> art.Trajectory:
    """
    Direct alias for rollout_tau_bench_task since it's now truly async.
    """
    return await rollout_tau_bench_task(model, task_index, step, phase)
    return await rollout_tau_bench_task(model, task_index, step, phase, is_shadow)


def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]:
@@ -206,6 +231,7 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
    parser.add_argument("--train-mode", type=str, default="sync_rl", choices=["sync_rl", "async_rl"], help="Training mode")
    parser.add_argument("--skip-eval", action="store_true", default=False, help="Skip evaluation")
    parser.add_argument("--add-shadow-trajectory", action="store_true", default=False, help="Add shadow trajectory")
    parser.add_argument("--messages-only", action="store_true", default=False, help="Only use messages for training")

    args = parser.parse_args()
    print(args)
@@ -235,6 +261,7 @@ def parse_args() -> tuple[RunConfig, TauBenchTrainingConfig, argparse.Namespace]
        max_num_steps=args.max_num_steps,
        skip_eval=args.skip_eval,
        add_shadow_trajectory=args.add_shadow_trajectory,
        messages_only=args.messages_only,
    )

    # Create training config
@@ -341,6 +368,7 @@ async def train(model: art.TrainableModel[TauBenchPolicyConfig]):
        max_concurrent_batches=3,
        skip_batches=await model.get_step(),
    ):
        # NOT UPDATED FOR TRAINING WITH SHADOW TRAJECTORIES
        if global_step % training_config.eval_steps == 0 and not config.skip_eval:
            print(f"\n--- Evaluating at Step {global_step} ---")
            await evaluate_model(model, config, global_step, val_task_indices)
@@ -396,8 +424,8 @@ async def train(model: art.TrainableModel[TauBenchPolicyConfig]):
            (
                art.TrajectoryGroup(
                    (
                        async_rollout_tau_bench_task(model, task_index, global_step, "train")
                        for _ in range(training_config.trajectories_per_group)
                        async_rollout_tau_bench_task(model, task_index, global_step, "train", is_shadow=config.add_shadow_trajectory and rollout_idx % training_config.trajectories_per_group == 0)
                        for rollout_idx in range(training_config.trajectories_per_group)
                    )
                )
                for task_index in batch
@@ -422,6 +450,7 @@ async def train(model: art.TrainableModel[TauBenchPolicyConfig]):
            config=art.TrainConfig(
                learning_rate=training_config.learning_rate
            ),
            _config=art.dev.TrainConfig(allow_training_without_logprobs=True if config.messages_only else False)
        )

        # Log progress
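One detail worth noting in the rollout changes above: the run config is now deep-copied per rollout before the shadow override mutates it, presumably so that switching a single rollout to gpt-4.1 cannot leak into sibling rollouts that share the same model config. A minimal stand-in sketch of that pattern (the config class and its default values here are hypothetical; only the copied-then-mutated fields follow the diff):

```python
# Stand-in illustration of the per-rollout config override used above.
import copy
from dataclasses import dataclass


@dataclass
class RunConfigStub:  # hypothetical stand-in for tau_bench.types.RunConfig
    model: str = "policy-model"
    model_provider: str = "hosted_vllm"
    api_key: str | None = "policy-key"
    base_url: str | None = "http://localhost:8000/v1"


shared = RunConfigStub()
config = copy.deepcopy(shared)   # isolate this rollout's view of the config
config.model = "gpt-4.1"         # shadow rollouts call the reference model...
config.model_provider = "openai"
config.api_key = None            # ...falling back to provider-default credentials
config.base_url = None

assert shared.model == "policy-model"  # the shared config is left untouched
```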
17 changes: 16 additions & 1 deletion dev/tau-bench/run_training.py
@@ -36,6 +36,7 @@
        "train_mode": "sync_rl",
        "skip_eval": False,
        "add_shadow_trajectory": False,
        "messages_only": False,
    }
}

@@ -152,7 +153,20 @@
models["018"]["trajectories_per_group"] = 10
models["018"]["reward_type"] = "real"


models["019"] = models["001"].copy()
models["019"]["model"] = "tau-bench-rl-019-2"
models["019"]["skip_eval"] = True
models["019"]["training_dataset_size"] = 10
models["019"]["trajectories_per_group"] = 8
models["019"]["groups_per_step"] = 5
models["019"]["num_epochs"] = 150
models["019"]["reward_type"] = "general_rm"
models["019"]["learning_rate"] = 8e-6
models["019"]["messages_only"] = True

models["020"] = models["019"].copy()
models["020"]["model"] = "tau-bench-rl-020-4"
models["020"]["add_shadow_trajectory"] = True

# models["013"] = models["001"].copy()
# models["013"]["model"] = "tau-bench-rl-013"
@@ -232,6 +246,7 @@ def launch_model(model_key: str):
        f"--train-mode {model_config['train_mode']}",
        f"{'--skip-eval' if model_config['skip_eval'] else ''}",
        f"{'--add-shadow-trajectory' if model_config['add_shadow_trajectory'] else ''}",
        f"{'--messages-only' if model_config['messages_only'] else ''}",
    ]

    run_script = textwrap.dedent(f"""
9 changes: 5 additions & 4 deletions dev/tau-bench/tau_bench/agents/tool_calling_agent.py
@@ -87,6 +87,7 @@ async def solve(
            )
            if env_response.done:
                forced_stop = False
                break
            if final_prompt_tokens > 20000 or res.choices[0].finish_reason == "length":
                break
        info["total_steps"] = curr_step_number + 1
@@ -104,8 +105,8 @@
class ToolCallingRLAgent(ToolCallingAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.api_key = kwargs.get("api_key")
        self.base_url = kwargs.get("base_url")
        self.api_key = kwargs.get("api_key", None)
        self.base_url = kwargs.get("base_url", None)
        self.choices = []

    async def llm_completion(self, messages: List[Dict[str, Any]]):
@@ -118,10 +119,10 @@ async def llm_completion(self, messages: List[Dict[str, Any]]):
            tools=self.tools_info,
            temperature=self.temperature,
            max_completion_tokens=1024,
            logprobs=True,
            logprobs=False if self.provider == "openai" else True,
        )
        choice = response.choices[0] # type: ignore
        assert isinstance(choice, Choices)
        assert isinstance(choice, Choices), f"Choice is not a Choices object: {choice}"
        self.choices.append(convert_litellm_choice_to_openai(choice))
        return response

2 changes: 1 addition & 1 deletion dev/tau-bench/tau_bench/envs/retail/tools/think.py
@@ -8,7 +8,7 @@ class Think(Tool):
    @staticmethod
    def invoke(data: Dict[str, Any], thought: str) -> str:
        # This method does not change the state of the data; it simply returns an empty string.
        return ""
        return "Thought Completed"

    @staticmethod
    def get_info() -> Dict[str, Any]:
1 change: 1 addition & 0 deletions dev/tau-bench/tau_bench/types.py
@@ -98,6 +98,7 @@ class RunConfig(BaseModel):
    max_num_steps: int = 30
    skip_eval: bool = False
    add_shadow_trajectory: bool = False
    messages_only: bool = False
class TauBenchTrainingConfig(BaseModel):
    """Training configuration for ART RL on tau-bench tasks"""
    trajectories_per_group: int = 6
9 changes: 9 additions & 0 deletions examples/art-e/all_models.py
@@ -159,3 +159,12 @@
models[
    "0001"
].config.training_config.judge_group_model_name = "openrouter/qwen/qwen3-32b"

models["0002"] = models["008"].model_copy(deep=True)
models["0002"].name = "email-agent-0002"
models["0002"].project = "email_agent_saumya_test"
assert models["0002"].config.training_config is not None
models[
    "0002"
].config.training_config.judge_group_model_name = "openrouter/qwen/qwen3-32b"
models["0002"].config.training_config.messages_only = True
2 changes: 1 addition & 1 deletion examples/art-e/art_e/project_types.py
@@ -20,7 +20,7 @@ class TrainingConfig(BaseModel):
    # Random seed to control which subset of the training data is sampled. When None, the sampler can
    # choose its own default (e.g., derive from the current time).
    training_dataset_seed: int | None = None

    messages_only: bool = False


class ProjectPolicyConfig(BaseModel):
    max_turns: int = 10
9 changes: 8 additions & 1 deletion examples/art-e/art_e/rollout.py
@@ -1,3 +1,4 @@
from typing import Any, Dict
import art
from art_e.data.types_enron import SyntheticQuery
from art import Trajectory
@@ -189,6 +190,9 @@ async def judge_correctness(
class ProjectTrajectory(Trajectory):
    generated_answer: str | None = None


def clean_message(message: Dict[str, Any]) -> Dict[str, Any]:
    return {k: v for k, v in message.items() if v is not None}


@retry(stop=stop_after_attempt(3))
# @weave.op(tracing_sample_rate=0.05) # type: ignore
@@ -318,7 +322,10 @@ async def return_final_answer(answer: str, sources: list[str]):
        # Our rollout is only set up to handle one tool call at a time, so just ignore any parallel tool calls.
        if choice.message.tool_calls is not None and len(choice.message.tool_calls) > 1:
            choice.message.tool_calls = choice.message.tool_calls[:1]
        traj.messages_and_choices.append(convert_litellm_choice_to_openai(choice)) # type: ignore
        if model.config.training_config.messages_only:
            traj.messages_and_choices.append(clean_message(convert_litellm_choice_to_openai(choice).message.model_dump())) # type: ignore
        else:
            traj.messages_and_choices.append(convert_litellm_choice_to_openai(choice)) # type: ignore

        if choice.message.tool_calls is None:
            rubric.bad_tool_call_name = True
1 change: 1 addition & 0 deletions examples/art-e/art_e/train.py
@@ -158,6 +158,7 @@ async def train(model: art.TrainableModel[ProjectPolicyConfig]):
            config=art.TrainConfig(
                learning_rate=model.config.training_config.learning_rate
            ),
            _config=art.dev.TrainConfig(allow_training_without_logprobs=True if model.config.training_config.messages_only else False)
        )

    await benchmark_model(model)