Skip to content

Commit 38bab88

Browse files
committed
more flexible TrainablePolicyModel
1 parent 069f04d commit 38bab88

File tree

17 files changed

+1311
-1264
lines changed

17 files changed

+1311
-1264
lines changed

dev/yes-no-maybe.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@
4141
"load_dotenv()\n",
4242
"\n",
4343
"\n",
44-
"api = art.LocalAPI()\n",
45-
"model = await api.get_or_create_model(\n",
44+
"model = art.TrainablePolicyModel(\n",
4645
" name=\"001\",\n",
4746
" project=\"yes-no-maybe\",\n",
4847
" base_model=\"Qwen/Qwen2.5-7B-Instruct\",\n",
4948
")\n",
49+
"await model.register_for_training(art.LocalAPI())\n",
5050
"\n",
5151
"\n",
5252
"async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory:\n",
@@ -90,7 +90,7 @@
9090
" ]\n",
9191
"]\n",
9292
"\n",
93-
"openai_client = await model.openai_client()\n",
93+
"openai_client = model.openai_client()\n",
9494
"for _ in range(await model.get_step(), 1_000):\n",
9595
" train_groups = await art.gather_trajectory_groups(\n",
9696
" (\n",

examples/2048/2048-single-turn.ipynb

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@
4747
"random.seed(42)\n",
4848
"\n",
4949
"\n",
50-
"api = art.LocalAPI()\n",
51-
"model = await api.get_or_create_model(\n",
50+
"model = art.TrainablePolicyModel(\n",
5251
" name=\"001\", project=\"2048-single-turn\", base_model=\"Qwen/Qwen2.5-7B-Instruct\"\n",
53-
")"
52+
")\n",
53+
"await model.register_for_training(art.LocalAPI())"
5454
]
5555
},
5656
{
@@ -78,7 +78,6 @@
7878
"\n",
7979
"@art.retry(exceptions=(openai.LengthFinishReasonError, requests.ReadTimeout))\n",
8080
"async def rollout(client: openai.AsyncOpenAI, iteration: int) -> art.Trajectory:\n",
81-
"\n",
8281
" game = generate_game()\n",
8382
"\n",
8483
" reward = 0\n",
@@ -87,7 +86,6 @@
8786
" trajectories: list[art.Trajectory] = []\n",
8887
"\n",
8988
" while True:\n",
90-
"\n",
9189
" trajectory = art.Trajectory(\n",
9290
" messages_and_choices=[\n",
9391
" {\n",
@@ -193,7 +191,7 @@
193191
" return trajectories\n",
194192
"\n",
195193
"\n",
196-
"openai_client = await model.openai_client()\n",
194+
"openai_client = model.openai_client()\n",
197195
"\n",
198196
"for i in range(await model.get_step(), 500):\n",
199197
" train_groups = await art.gather_trajectory_groups(\n",

0 commit comments

Comments
 (0)