Skip to content

More flexible TrainableModel #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 15, 2025
Merged

More flexible TrainableModel #51

merged 1 commit into from
Apr 15, 2025

Conversation

corbt
Copy link
Contributor

@corbt corbt commented Apr 15, 2025

This PR makes a few changes to the API shape, specifically focused on Model and LocalAPI.

It introduces an abstraction where we separate out PolicyModel, which could be any LLM, and a TrainablePolicyModel, which is specifically a policy that can be trained by our system. This will let us log trajectories from both PolicyModel and TrainablePolicyModel in a unified way.

It also adds a new config field to PolicyModel. This is opaque to our system, but is something we can log to wandb and our file system in the future to track hparams associated with each model run. I'm using the config field in the following way:

agent_002 = art.TrainablePolicyModel(
    name="email-agent-002",
    project="email_agent",
    base_model="Qwen/Qwen2.5-14B-Instruct",
    config=ProjectPolicyConfig(
        max_turns=10,
        training_config=TrainingConfig(
            trajectories_per_group=6,
            groups_per_step=1,
            learning_rate=1.2e-5,
            eval_steps=30,
            val_set_size=100,
            training_dataset_size=4000,
            batch_size=8,
            num_epochs=4,
        ),
    ),
)

agent_004 = agent_002.model_copy(deep=True)
assert isinstance(agent_004.config, ProjectPolicyConfig)
agent_004.name = "email-agent-004"
agent_004.config.max_turns = 30

And then within my training and rollout functions, adjusting behavior based on the config above. I find this pattern helps me stay sane by tracking what properties each model was called with, and keeping a record around of old training runs (by keeping the old config in the codebase).

@corbt corbt force-pushed the model-methods branch 5 times, most recently from b97b9b3 to 38bab88 Compare April 15, 2025 22:25
@corbt corbt changed the title [WIP]: TrainablePolicyModel More flexible TrainablePolicyModel Apr 15, 2025
@corbt corbt changed the title More flexible TrainablePolicyModel More flexible TrainableModel Apr 15, 2025
@corbt corbt merged commit 095ab48 into main Apr 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant