import boto3
import httpx
- import torch
from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download
from PIL import Image
@@ -236,10 +235,6 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
# Change temperature as number of attempts increases to overcome repetition issues at expense of quality
query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]

- # Enable guided decoding regex if needed
- if args.guided_decoding:
- query["guided_regex"] = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n---\n[\s\S]*"
-
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")

try:
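The removed --guided_decoding branch constrained generation to a fixed YAML front-matter layout. For reference, a standalone check of what that regex accepted; the sample response below is made up, the pattern is copied from the removed line:

import re

guided_regex = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n---\n[\s\S]*"

# A made-up model response in the old YAML front-matter format.
sample = "---\nprimary_language: en\nis_rotation_valid: true\nrotation_correction: 0\nis_table: false\nis_diagram: false\n---\nPage text goes here."

assert re.fullmatch(guided_regex, sample) is not None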
@@ -258,24 +253,14 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
raise ValueError("Response exceeded model_max_context, cannot use this response")
-
- if base_response_data["choices"][0]["finish_reason"] != "stop":
- local_anchor_text_len = max(1, local_anchor_text_len // 2)
- logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
- raise ValueError("Response did not finish with reason code 'stop', cannot use this response")

metrics.add_metrics(
server_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
server_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
)

- model_response_markdown = base_response_data["choices"][0]["message"]["content"]
-
- # Somewhat temporary code, will need to refactor
- from olmocr.train.dataloader import FrontMatterParser
- parser = FrontMatterParser(front_matter_class=PageResponse)
- front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
- page_response = parser._parse_front_matter(front_matter, text)
+ model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
+ page_response = PageResponse(**model_response_json)

if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
logger.info(
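With this hunk the completion is treated as a JSON object and unpacked directly into PageResponse rather than going through the YAML front-matter parser. A minimal sketch of that path, assuming the response keys mirror the fields named in the old guided regex; the dataclass and its natural_text field are stand-ins for olmocr's own PageResponse, not the real definition:

import json
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class PageResponse:  # stand-in for olmocr's PageResponse
    primary_language: Optional[str]
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool
    natural_text: Optional[str]  # assumed field holding the transcribed page text

# A made-up completion in the new JSON format.
content = '{"primary_language": "en", "is_rotation_valid": true, "rotation_correction": 0, "is_table": false, "is_diagram": false, "natural_text": "Page text goes here."}'
page_response = PageResponse(**json.loads(content))
assert page_response.is_rotation_valid and not page_response.is_table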
@@ -581,10 +566,6 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
async def vllm_server_task(model_name_or_path, args, semaphore):
- # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
- gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert to GB
- mem_fraction_arg = ["--gpu-memory-utilization", "0.80"] if gpu_memory < 60 else []
-
cmd = [
"vllm",
"serve",
@@ -596,8 +577,11 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
"warning",
"--served-model-name",
"olmocr",
+ "--tensor-parallel-size",
+ str(args.tensor_parallel_size),
+ "--data-parallel-size",
+ str(args.data_parallel_size),
]
- cmd.extend(mem_fraction_arg)

proc = await asyncio.create_subprocess_exec(
*cmd,
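The two new list entries pass the parallelism settings straight through to vllm serve. A sketch of how they extend the command; the model path and parallel sizes are placeholders, only the flag names and the served-model-name come from the diff:

import shlex

def parallelism_args(tensor_parallel_size: int, data_parallel_size: int) -> list[str]:
    # vLLM shards one model replica across tensor_parallel_size GPUs and runs
    # data_parallel_size such replicas, so the server needs tp * dp GPUs in total.
    return [
        "--tensor-parallel-size", str(tensor_parallel_size),
        "--data-parallel-size", str(data_parallel_size),
    ]

cmd = ["vllm", "serve", "/path/to/model", "--served-model-name", "olmocr", *parallelism_args(2, 2)]
print(shlex.join(cmd))
# vllm serve /path/to/model --served-model-name olmocr --tensor-parallel-size 2 --data-parallel-size 2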
@@ -637,7 +621,7 @@ async def process_line(line):
if match:
last_running_req = int(match.group(1))

- match = re.search(r"Waiting: (\d+)", line)
+ match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
if match:
last_queue_req = int(match.group(1))
logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
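The widened pattern tolerates either spelling of the queue-depth field in the vLLM server log. A quick standalone check; the log lines below are illustrative, not verbatim vLLM output:

import re

pattern = re.compile(r"(?:Waiting|Pending):\s*(\d+)")
for line in ("Running: 4 reqs, Waiting: 7 reqs", "Running: 4 reqs, Pending: 7 reqs"):
    match = pattern.search(line)
    assert match is not None and int(match.group(1)) == 7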
@@ -1025,8 +1009,7 @@ async def main():
)
parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1024)
- parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=3000)
- parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
+ parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=6000)

# Beaker/job running stuff
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
@@ -1039,6 +1022,8 @@ async def main():
parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
parser.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
+ parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
+ parser.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
args = parser.parse_args()

global workspace_s3, pdf_s3
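Note that argparse converts the dashed flag names to underscored attributes, which is what vllm_server_task reads as args.tensor_parallel_size and args.data_parallel_size. A small standalone check mirroring the two parser.add_argument calls above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--data-parallel-size", "-dp", type=int, default=1)

args = parser.parse_args(["-tp", "4", "--data-parallel-size", "2"])
assert args.tensor_parallel_size == 4 and args.data_parallel_size == 2

Both flags default to 1, so single-GPU runs behave as before.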
@@ -1239,4 +1224,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
+ asyncio.run(main())