Commit 02f0706

Reverting back to json pipeline as it seems better by default
1 parent 8ae9104 commit 02f0706


olmocr/pipeline.py

Lines changed: 11 additions & 26 deletions
@@ -23,7 +23,6 @@
 
 import boto3
 import httpx
-import torch
 from botocore.exceptions import ClientError
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -236,10 +235,6 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
         # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
         query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]
 
-        # Enable guided decoding regex if needed
-        if args.guided_decoding:
-            query["guided_regex"] = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n---\n[\s\S]*"
-
         logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
 
         try:
@@ -258,24 +253,14 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
                 local_anchor_text_len = max(1, local_anchor_text_len // 2)
                 logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
                 raise ValueError("Response exceeded model_max_context, cannot use this response")
-
-            if base_response_data["choices"][0]["finish_reason"] != "stop":
-                local_anchor_text_len = max(1, local_anchor_text_len // 2)
-                logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
-                raise ValueError("Response did not finish with reason code 'stop', cannot use this response")
 
             metrics.add_metrics(
                 server_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                 server_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
             )
 
-            model_response_markdown = base_response_data["choices"][0]["message"]["content"]
-
-            # Somewhat temporary code, will need to refactor
-            from olmocr.train.dataloader import FrontMatterParser
-            parser = FrontMatterParser(front_matter_class=PageResponse)
-            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
-            page_response = parser._parse_front_matter(front_matter, text)
+            model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
+            page_response = PageResponse(**model_response_json)
 
             if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
                 logger.info(
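The restored path simply json.loads() the chat completion content and splats the resulting dict into PageResponse. Below is a minimal sketch of what that expects, using a stand-in dataclass whose fields are inferred from the guided-decoding regex removed above; natural_text is an assumed extra field for the page text, not something shown in this diff:

    import json
    from dataclasses import dataclass
    from typing import Optional

    # Stand-in for olmocr's PageResponse; the real class lives elsewhere in the repo.
    @dataclass(frozen=True)
    class PageResponse:
        primary_language: Optional[str]
        is_rotation_valid: bool
        rotation_correction: int
        is_table: bool
        is_diagram: bool
        natural_text: Optional[str]  # assumed field holding the extracted page text

    # The reverted pipeline parses the raw chat completion content as JSON
    # and expands it directly into the dataclass.
    raw = (
        '{"primary_language": "en", "is_rotation_valid": true, '
        '"rotation_correction": 0, "is_table": false, "is_diagram": false, '
        '"natural_text": "Example page text"}'
    )
    page_response = PageResponse(**json.loads(raw))
    print(page_response.natural_text)  # -> Example page text

A malformed response (invalid JSON or a missing key) raises at parse or construction time, which the surrounding try/except and retry loop presumably treat as a failed attempt.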
@@ -581,10 +566,6 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
 
 
 async def vllm_server_task(model_name_or_path, args, semaphore):
-    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
-    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert to GB
-    mem_fraction_arg = ["--gpu-memory-utilization", "0.80"] if gpu_memory < 60 else []
-
     cmd = [
         "vllm",
         "serve",
@@ -596,8 +577,11 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
         "warning",
         "--served-model-name",
         "olmocr",
+        "--tensor-parallel-size",
+        str(args.tensor_parallel_size),
+        "--data-parallel-size",
+        str(args.data_parallel_size),
     ]
-    cmd.extend(mem_fraction_arg)
 
     proc = await asyncio.create_subprocess_exec(
         *cmd,
@@ -637,7 +621,7 @@ async def process_line(line):
         if match:
             last_running_req = int(match.group(1))
 
-        match = re.search(r"Waiting: (\d+)", line)
+        match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
         if match:
             last_queue_req = int(match.group(1))
             logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
@@ -1025,8 +1009,7 @@ async def main():
     )
     parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
     parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1024)
-    parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=3000)
-    parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
+    parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=6000)
 
     # Beaker/job running stuff
     parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
@@ -1039,6 +1022,8 @@ async def main():
     parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
     parser.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
+    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
+    parser.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
     args = parser.parse_args()
 
     global workspace_s3, pdf_s3
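Since argparse converts dashes in option names to underscores on the parsed namespace, the new flags surface as args.tensor_parallel_size and args.data_parallel_size, which is how vllm_server_task reads them above. A minimal standalone illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--data-parallel-size", "-dp", type=int, default=1)

    # Dashed option names become underscored attributes on the namespace.
    args = parser.parse_args(["-tp", "2", "--data-parallel-size", "4"])
    print(args.tensor_parallel_size, args.data_parallel_size)  # -> 2 4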
@@ -1239,4 +1224,4 @@ async def main():
 
 
 if __name__ == "__main__":
-    asyncio.run(main())
+    asyncio.run(main())
