"""

import argparse
+ import asyncio
import gc
+ import os
+ import glob
+ import tempfile
+ import shutil
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from datasets import load_dataset
+ from huggingface_hub import snapshot_download
import random
import numpy as np
from typing import List, Dict
import base64
from io import BytesIO
+ import PIL.Image
+ import logging
+
+ from olmocr.pipeline import build_page_query
+ from olmocr.s3_utils import download_directory
+
+ logger = logging.getLogger(__name__)
+
+
+ async def download_model(model_name_or_path: str, max_retries: int = 5):
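+     """Resolve a model reference to something vLLM/HF can load: s3://, gs://, and
+     weka:// URIs are mirrored into ~/.cache/olmocr/model, absolute local directories
+     are used in place, and anything else is treated as a Hugging Face repo id.
+     Retries with exponential backoff before giving up."""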
+     for retry in range(max_retries):
+         try:
+             if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
+                 logger.info(f"Downloading model directory from '{model_name_or_path}'")
+                 model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
+                 # Delete existing model cache directory if it exists
+                 if os.path.exists(model_cache_dir):
+                     shutil.rmtree(model_cache_dir)
+                 download_directory([model_name_or_path], model_cache_dir)
+                 return model_cache_dir
+             elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
+                 logger.info(f"Using local model path at '{model_name_or_path}'")
+                 return model_name_or_path
+             else:
+                 logger.info(f"Downloading model from Hugging Face '{model_name_or_path}'")
+                 snapshot_download(repo_id=model_name_or_path)
+                 return model_name_or_path
+         except Exception:
+             if retry == max_retries - 1:
+                 raise  # Raise on final attempt and fail the job
+             logger.warning(f"Model download failed (attempt {retry + 1}/{max_retries}), retrying...")
+             await asyncio.sleep(2 ** retry)  # Exponential backoff


def image_to_base64_data_url(image):
@@ -25,75 +63,95 @@ def image_to_base64_data_url(image):
    return f"data:image/png;base64,{img_str}"


- def load_wildvision_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, str]]:
-     """Load prompts and images from WildVision-bench dataset with fixed random seed."""
-     print(f"Loading WildVision-bench dataset with {num_samples} samples and seed {seed}")
+ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, str]]:
+     """Load prompts and images from the olmOCR-mix-0225-benchmarkset dataset with a fixed random seed."""
+     print(f"Loading olmOCR-mix-0225-benchmarkset dataset with {num_samples} samples and seed {seed}")

    # Set random seed for reproducibility
    random.seed(seed)
    np.random.seed(seed)

-     # Load dataset
-     dataset = load_dataset("WildVision/wildvision-bench", "vision_bench_0701", split="test", streaming=True)
-
-     # Collect prompts and images
-     samples = []
-     examined = 0
-     for example in dataset:
-         examined += 1
-         if len(samples) >= num_samples * 2:  # Collect extra to allow filtering
-             break
+     # Download dataset to a temporary directory
+     with tempfile.TemporaryDirectory() as temp_dir:
+         print("Downloading dataset...")
+         dataset_path = snapshot_download(
+             repo_id="allenai/olmOCR-mix-0225-benchmarkset",
+             repo_type="dataset",
+             local_dir=temp_dir,
+             allow_patterns="pdfs/*.pdf"  # Only download PDF files
+         )

-         # Extract prompt and image from the example
-         prompt = example.get('instruction', '').strip()
-         image = example.get('image', None)
+         # Find all PDF files in the pdfs directory
+         pdf_pattern = os.path.join(dataset_path, "pdfs", "*.pdf")
+         pdf_files = glob.glob(pdf_pattern)

-         # Filter by prompt length and ensure we have both prompt and image
-         if prompt and image and 10 < len(prompt) <= max_length:
-             samples.append({
-                 'prompt': prompt,
-                 'image': image  # This is already a PIL Image object
-             })
+         if not pdf_files:
+             raise ValueError(f"No PDF files found in {pdf_pattern}")

-         # Stop if we've examined too many without finding enough
-         if examined > num_samples * 10:
-             break
-
-     # Randomly sample exactly num_samples
-     if len(samples) < num_samples:
-         print(f"Only found {len(samples)} valid samples out of {examined} examined")
-         if len(samples) == 0:
-             raise ValueError("No valid samples found in dataset")
-         samples = random.choices(samples, k=num_samples)
-     else:
-         samples = random.sample(samples, num_samples)
-
-     print(f"Selected {len(samples)} samples for comparison")
-     return samples
-
+         print(f"Found {len(pdf_files)} PDF files")
+
+         # Randomly sample num_samples PDFs
+         if len(pdf_files) > num_samples:
+             sampled_pdfs = random.sample(pdf_files, num_samples)
+         else:
+             sampled_pdfs = pdf_files
+             print(f"Warning: Only {len(pdf_files)} PDFs available, fewer than requested {num_samples}")
+
+         # Process each PDF and build queries
+         queries = []
+         for pdf_path in sampled_pdfs:
+             try:
+                 # Build page query for page 1 of each PDF
+                 # (awaited directly: asyncio.run() cannot be nested inside the running event loop)
+                 query = await build_page_query(
+                     local_pdf_path=pdf_path,
+                     page=1,
+                     target_longest_image_dim=1280,
+                     image_rotation=0
+                 )
+                 queries.append(query)
+             except Exception as e:
+                 print(f"Error processing {os.path.basename(pdf_path)}: {e}")
+                 continue
+
+         print(f"Successfully processed {len(queries)} PDFs")
+         return queries


def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
    """Process a single prompt with image and return comparison results."""
-     prompt = sample['prompt']
-     image = sample['image']  # Already a PIL Image object
+     # Extract messages from the sample (which is the output of build_page_query)
+     messages = sample['messages']
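+     # build_page_query is expected to return an OpenAI-style chat payload:
+     # {"messages": [{"role": "user", "content": [{"type": "text", ...},
+     #                                            {"type": "image_url", ...}]}]}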
+
+     # Extract the text prompt and image from the messages
+     user_message = messages[0]
+     text_prompt = None
+     image_base64 = None
+
+     for content in user_message['content']:
+         if content['type'] == 'text':
+             text_prompt = content['text']
+         elif content['type'] == 'image_url':
+             image_url = content['image_url']['url']
+             # Extract base64 data after the comma
+             if ',' in image_url:
+                 image_base64 = image_url.split(',')[1]
+             else:
+                 image_base64 = image_url
+
+     if text_prompt is None or image_base64 is None:
+         raise ValueError("Failed to extract text prompt or image from messages")
+
+     # Decode the base64 image to PIL Image
+     image_bytes = base64.b64decode(image_base64)
+     image = PIL.Image.open(BytesIO(image_bytes))

    print(f"\n{'=' * 80}")
-     print(f"PROMPT: {prompt[:100]}..." if len(prompt) > 100 else f"PROMPT: {prompt}")
+     print(f"PROMPT: {text_prompt[:100]}..." if len(text_prompt) > 100 else f"PROMPT: {text_prompt}")
    print(f"IMAGE: {image.size} {image.mode}")

    # Generate with vLLM
    print("\n=== vLLM Generation ===")
-     # Convert image to base64 data URL
-     image_data_url = image_to_base64_data_url(image)

-     # For VLMs, vLLM expects the message format with image
-     messages = [{
-         "role": "user",
-         "content": [
-             {"type": "image_url", "image_url": {"url": image_data_url}},
-             {"type": "text", "text": prompt}
-         ]
-     }]
+     # For vLLM, use the messages exactly as they come out of build_page_query
    outputs = llm.chat(messages, sampling_params)
    output = outputs[0]

@@ -110,19 +168,19 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp

    # HuggingFace forward pass
    print("\n=== HuggingFace Forward Pass ===")
-     # Prepare inputs for HF model
+     # Prepare inputs for HF model using the extracted image and text
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
-                 {"type": "text", "text": prompt}
+                 {"type": "text", "text": text_prompt}
            ]
        }
    ]
-     text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+     hf_text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
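+     # Renamed from text_prompt so it does not shadow the prompt text extracted from the messages above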
    inputs = processor(
-         text=[text_prompt],
+         text=[hf_text_prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)
@@ -233,7 +291,7 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
    }


- def main():
+ async def async_main():
    parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
                        help="Model name or path")
@@ -253,14 +311,17 @@ def main():
    print(f"Max tokens: {args.max_tokens}")
    print(f"Temperature: {args.temperature}")
    print(f"Probability threshold: {args.prob_threshold}")
-     print(f"Loading {args.num_prompts} samples from WildVision-bench\n")
+     print(f"Loading {args.num_prompts} samples from olmOCR-mix-0225-benchmarkset\n")
+
+     # Download the model before loading prompts
+     model_path = await download_model(args.model)

    # Load prompts and images
-     samples = load_wildvision_prompts(num_samples=args.num_prompts, seed=args.seed)
+     samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)

    # Create vLLM engine
    print("\n=== Creating vLLM Engine ===")
-     llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)
+     llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
@@ -278,9 +339,9 @@ def main():
    # Load HuggingFace model and processor
    print("\n=== Loading HuggingFace Model ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     processor_hf = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
+     processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    hf_model = AutoModelForVision2Seq.from_pretrained(
-         args.model,
+         model_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
@@ -295,7 +356,7 @@ def main():
        print(f"{'#' * 80}")

        # Recreate vLLM for each prompt
-         llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)
+         llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)

        # Process single sample
        result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
@@ -319,5 +380,9 @@ def main():
    print(f"{'=' * 80}")


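+ # Thin synchronous wrapper so the __main__ entry point can drive the async pipeline.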
+ def main():
+     asyncio.run(async_main())
+
+

if __name__ == "__main__":
    main()