Description
When running LocalBackend on a RunPod instance, I'm seeing a very long (possibly indefinite) hang after the training script has, by all appearances, finished. Sample code below, taken from the tic-tac-toe-local.py example.
import asyncio

import art
from art.local import LocalBackend

# rollout() is defined earlier in the tic-tac-toe example and is omitted here

DESTROY_AFTER_RUN = False


async def main():
    print(0)
    # run from the root of the repo
    backend = LocalBackend()
    print(1)
    model = art.TrainableModel(
        name="agent-001",
        project="tic-tac-toe-agent",
        base_model="Qwen/Qwen2.5-3B-Instruct",
    )
    print(2)
    await backend._experimental_pull_from_s3(model)
    print(3)
    await model.register(backend)
    print(4)
    step = await model.get_step()
    print(f"Step: {step}")
    for i in range(await model.get_step(), 101):
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(model, i, is_validation=False) for _ in range(200)
                )
                for _ in range(1)
            ),
            pbar_desc="gather",
        )
        await model.delete_checkpoints()
        await model.train(train_groups, config=art.TrainConfig(learning_rate=1e-4))
        await backend._experimental_push_to_s3(model)
    print(5)
    # res = await backend._experimental_deploy(model=model, verbose=True)
    # print(res)
    print(6)
    if DESTROY_AFTER_RUN:
        await backend.down()


if __name__ == "__main__":
    asyncio.run(main())
End of stdout:
Unsloth: Will smartly offload gradients to save VRAM!
train: 100%|██████████████████████████████████████████| 16/16 [00:18<00:00, 1.15s/it, loss=0.358, grad_norm=0.219, policy_loss=0.358]
5
6
The stdout includes the number 6, indicating that main() ran to completion, but the process hangs anyway instead of exiting.
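For reference, a quick way to see what keeps the process alive after main() returns is to list the remaining threads and dump a traceback for each one. This is only a diagnostic sketch using the standard library (threading and faulthandler); main refers to the coroutine in the snippet above.

import asyncio
import faulthandler
import sys
import threading

if __name__ == "__main__":
    asyncio.run(main())
    # main() has returned at this point; any non-daemon thread listed below
    # is what prevents the interpreter from exiting.
    print("Live threads after main() returned:")
    for t in threading.enumerate():
        print(f"  {t.name} (daemon={t.daemon})")
    # Dump a stack trace for every thread to see where each one is blocked.
    faulthandler.dump_traceback(file=sys.stderr, all_threads=True)

If only the MainThread shows up here, the hang is more likely caused by a stuck child process than by a lingering thread.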