Open
Description
Problem
task_init() on SandboxEnvironment is an asynchronous function, but it is awaited inside a sequential loop, so sandbox startups never overlap.
In inspect_ai/_eval/run.py, the startup_sandbox_environments() function does, in simplified form:
for sandboxenv in sandboxenvs:
    await task_init("startup", sandboxenv.sandbox.config)  # Blocks!
Even with max_tasks > 1, each task_init() waits for its predecessor to complete before starting.
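The blocking behaviour can be illustrated without inspect_ai at all. The sketch below is only a stand-in: fake_task_init, DELAY, run_sequential, and run_concurrent are hypothetical names invented for this example, and asyncio.gather is used in place of the anyio task group proposed in the fix. It records the order of start/done events and the elapsed time for both approaches.

```python
import asyncio
import time

DELAY = 0.05  # hypothetical stand-in for a slow task_init (e.g. a docker build)


async def fake_task_init(name: str, events: list[str]) -> None:
    """Hypothetical stand-in for a sandbox task_init(); not inspect_ai code."""
    events.append(f"{name}: start")
    await asyncio.sleep(DELAY)
    events.append(f"{name}: done")


async def run_sequential(names: list[str]) -> tuple[list[str], float]:
    events: list[str] = []
    began = time.perf_counter()
    for name in names:
        # Each await completes before the next loop iteration begins.
        await fake_task_init(name, events)
    return events, time.perf_counter() - began


async def run_concurrent(names: list[str]) -> tuple[list[str], float]:
    events: list[str] = []
    began = time.perf_counter()
    # gather schedules every coroutine before waiting on any of them.
    await asyncio.gather(*(fake_task_init(n, events) for n in names))
    return events, time.perf_counter() - began


if __name__ == "__main__":
    seq_events, seq_time = asyncio.run(run_sequential(["task_1", "task_2"]))
    con_events, con_time = asyncio.run(run_concurrent(["task_1", "task_2"]))
    print("sequential:", seq_events, f"{seq_time:.2f}s")
    print("concurrent:", con_events, f"{con_time:.2f}s")
```

In the sequential run, task_2 only starts after task_1 is done, and the total time is roughly the sum of the delays. In the concurrent run, both tasks start before either finishes, and the total is roughly one delay.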
Fix
In inspect_ai/_eval/run.py:517-531:
Before:
for sandboxenv in sandboxenvs:
    # find type
    sandboxenv_type = registry_find_sandboxenv(sandboxenv.sandbox.type)
    # run startup
    task_init = cast(TaskInit, getattr(sandboxenv_type, "task_init"))
    with chdir(sandboxenv.run_dir), environ_vars(dict(sandboxenv.env)):
        await task_init("startup", sandboxenv.sandbox.config)
    # append cleanup method
    task_cleanup = cast(TaskCleanup, getattr(sandboxenv_type, "task_cleanup"))
    cleanups.append(
        (task_cleanup, sandboxenv.sandbox.config, sandboxenv.run_dir)
    )
After:
async with anyio.create_task_group() as tg:
    for sandboxenv in sandboxenvs:
        # find type
        sandboxenv_type = registry_find_sandboxenv(sandboxenv.sandbox.type)
        # run startup
        task_init = cast(TaskInit, getattr(sandboxenv_type, "task_init"))
        task_cleanup = cast(TaskCleanup, getattr(sandboxenv_type, "task_cleanup"))
        # Closure to capture loop variables for concurrent execution
        async def init_env(env=sandboxenv, init_fn=task_init, cleanup_fn=task_cleanup):
            with chdir(env.run_dir), environ_vars(dict(env.env)):
                await init_fn("startup", env.sandbox.config)
            cleanups.append((cleanup_fn, env.sandbox.config, env.run_dir))
        tg.start_soon(init_env)
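The default arguments on init_env matter because the scheduled functions only run after the loop has moved on. The following sketch shows the difference between late binding and default-argument binding in isolation. All names here (late_binding, default_bound, env_a, env_b) are hypothetical, and asyncio.gather stands in for the task group.

```python
import asyncio


async def main() -> tuple[list[str], list[str]]:
    late: list[str] = []
    bound: list[str] = []
    late_fns = []
    bound_fns = []

    for env in ["env_a", "env_b"]:
        async def late_binding() -> None:
            # Looks up `env` when it runs, i.e. after the loop has finished.
            late.append(env)

        async def default_bound(env: str = env) -> None:
            # `env` was fixed as a default value when the function was defined.
            bound.append(env)

        late_fns.append(late_binding)
        bound_fns.append(default_bound)

    # Nothing runs until the loop is over, similar to functions queued on a task group.
    await asyncio.gather(*(fn() for fn in late_fns + bound_fns))
    return late, bound


if __name__ == "__main__":
    print(asyncio.run(main()))
    # prints (['env_b', 'env_b'], ['env_a', 'env_b'])
```

Without the defaults, every queued function would see the last sandboxenv, task_init, and task_cleanup from the loop.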
Reproduction
Running this mock Sandbox:
@classmethod
async def task_init(
    cls, task_name: str, config: SandboxEnvironmentConfigType | None
) -> None:
    """Simulate slow initialization (like building Docker images or pulling)."""
    task_id = config.task_id  # type: ignore
    log(f"Task {task_id}: Starting task_init")
    # Simulate slow setup (docker build, pip install, etc)
    await asyncio.sleep(20)
    log(f"Task {task_id}: Completed task_init")
With:
eval(
    tasks,
    max_tasks=2,
    max_sandboxes=2,
    max_samples=1,
)
Without the fix, we get this output:
12:00:00.000 - Task task_1: Starting task_init
12:00:20.000 - Task task_1: Completed task_init
12:00:20.001 - Task task_2: Starting task_init # waits 20s
12:00:40.001 - Task task_2: Completed task_init
With our fix we get:
12:00:00.000 - Task task_1: Starting task_init
12:00:00.001 - Task task_2: Starting task_init # starts immediately
12:00:20.000 - Task task_1: Completed task_init
12:00:20.001 - Task task_2: Completed task_init
Appendix
complete minimal reproduction: