olmocr/pipeline.py (20 changes: 13 additions & 7 deletions)
@@ -570,7 +570,7 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
semaphore.release()


-async def vllm_server_task(model_name_or_path, args, semaphore):
+async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
cmd = [
"vllm",
"serve",
@@ -594,6 +594,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
if args.max_model_len is not None:
cmd.extend(["--max-model-len", str(args.max_model_len)])

+if unknown_args:
+    cmd.extend(unknown_args)
+
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
@@ -683,12 +686,12 @@ async def timeout_task():
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


-async def vllm_server_host(model_name_or_path, args, semaphore):
+async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
MAX_RETRIES = 5
retry = 0

while retry < MAX_RETRIES:
-await vllm_server_task(model_name_or_path, args, semaphore)
+await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
logger.warning("VLLM server task ended")
retry += 1

@@ -998,7 +1001,7 @@ def process_output_file(s3_path):


async def main():
-parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
+parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline.")
parser.add_argument(
"workspace",
help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
@@ -1030,7 +1033,10 @@ async def main():
parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")

vllm_group = parser.add_argument_group("VLLM Forwarded arguments")
vllm_group = parser.add_argument_group(
"VLLM arguments",
"These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
)
vllm_group.add_argument(
"--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve)."
)
@@ -1051,7 +1057,7 @@
beaker_group.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
beaker_group.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")

-args = parser.parse_args()
+args, unknown_args = parser.parse_known_args()

logger.info(
"If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
@@ -1196,7 +1202,7 @@ async def main():
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)

-vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore))
+vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))

await vllm_server_ready()

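
The change above boils down to one pattern: argparse's parse_known_args() splits the command line into the flags the pipeline itself understands and a list of leftovers, and those leftovers are appended verbatim to the `vllm serve` command. Below is a minimal, self-contained sketch of that pattern, not the actual olmocr code; the names launch_vllm and parse_cli and the example flags are illustrative only.

# Sketch only: shows how unrecognized CLI flags can be forwarded to an
# external command, as this PR does for `vllm serve`.
import argparse
import asyncio


async def launch_vllm(model: str, unknown_args=None):
    # Build the serve command; anything the parser did not recognize is
    # passed straight through to vLLM.
    cmd = ["vllm", "serve", model]
    if unknown_args:
        cmd.extend(unknown_args)
    proc = await asyncio.create_subprocess_exec(*cmd)
    await proc.wait()


def parse_cli():
    parser = argparse.ArgumentParser(description="Forwarding sketch")
    parser.add_argument("model")
    parser.add_argument("--max-model-len", type=int)
    # parse_known_args() returns (namespace, list_of_unrecognized_tokens)
    # instead of erroring out on flags this parser does not define.
    args, unknown = parser.parse_known_args()
    return args, unknown


if __name__ == "__main__":
    cli_args, extra = parse_cli()
    asyncio.run(launch_vllm(cli_args.model, extra))

With the pipeline patched this way, an invocation along the lines of `python -m olmocr.pipeline <workspace> --tensor-parallel-size 2` should hand --tensor-parallel-size through to vllm serve (assuming that flag is valid for your vLLM version) instead of failing argument parsing with an unrecognized-argument error.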