
Commit 939a76a

Adding a compare vllm checkpoint script
1 parent 2460895 commit 939a76a

File tree

1 file changed (+127 −62 lines)

olmocr/train/compare_vllm_checkpoint.py

Lines changed: 127 additions & 62 deletions
@@ -5,16 +5,54 @@
 """
 
 import argparse
+import asyncio
 import gc
+import os
+import glob
+import tempfile
+import shutil
 import torch
 from vllm import LLM, SamplingParams
 from transformers import AutoProcessor, AutoModelForVision2Seq
 from datasets import load_dataset
+from huggingface_hub import snapshot_download
 import random
 import numpy as np
 from typing import List, Dict
 import base64
 from io import BytesIO
+import PIL.Image
+import logging
+
+from olmocr.pipeline import build_page_query
+from olmocr.s3_utils import download_directory
+
+logger = logging.getLogger(__name__)
+
+
+async def download_model(model_name_or_path: str, max_retries: int = 5):
+    for retry in range(max_retries):
+        try:
+            if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
+                logger.info(f"Downloading model directory from '{model_name_or_path}'")
+                model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
+                # Delete existing model cache directory if it exists
+                if os.path.exists(model_cache_dir):
+                    shutil.rmtree(model_cache_dir)
+                download_directory([model_name_or_path], model_cache_dir)
+                return model_cache_dir
+            elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
+                logger.info(f"Using local model path at '{model_name_or_path}'")
+                return model_name_or_path
+            else:
+                logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
+                snapshot_download(repo_id=model_name_or_path)
+                return model_name_or_path
+        except Exception:
+            if retry == max_retries - 1:
+                raise  # Raise on final attempt and fail the job
+            logger.warning(f"Model download failed (attempt {retry + 1}/{max_retries}), retrying...")
+            await asyncio.sleep(2 ** retry)  # Exponential backoff
 
 
 def image_to_base64_data_url(image):
@@ -25,75 +63,95 @@ def image_to_base64_data_url(image):
     return f"data:image/png;base64,{img_str}"
 
 
-def load_wildvision_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, str]]:
-    """Load prompts and images from WildVision-bench dataset with fixed random seed."""
-    print(f"Loading WildVision-bench dataset with {num_samples} samples and seed {seed}")
+def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, str]]:
+    """Load prompts and images from olmOCR-mix-0225-benchmarkset dataset with fixed random seed."""
+    print(f"Loading olmOCR-mix-0225-benchmarkset dataset with {num_samples} samples and seed {seed}")
 
     # Set random seed for reproducibility
     random.seed(seed)
     np.random.seed(seed)
 
-    # Load dataset
-    dataset = load_dataset("WildVision/wildvision-bench", "vision_bench_0701", split="test", streaming=True)
-
-    # Collect prompts and images
-    samples = []
-    examined = 0
-    for example in dataset:
-        examined += 1
-        if len(samples) >= num_samples * 2:  # Collect extra to allow filtering
-            break
+    # Download dataset to a temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        print("Downloading dataset...")
+        dataset_path = snapshot_download(
+            repo_id="allenai/olmOCR-mix-0225-benchmarkset",
+            repo_type="dataset",
+            local_dir=temp_dir,
+            allow_patterns="pdfs/*.pdf"  # Only download PDF files
+        )
 
-        # Extract prompt and image from the example
-        prompt = example.get('instruction', '').strip()
-        image = example.get('image', None)
+        # Find all PDF files in the pdfs directory
+        pdf_pattern = os.path.join(dataset_path, "pdfs", "*.pdf")
+        pdf_files = glob.glob(pdf_pattern)
 
-        # Filter by prompt length and ensure we have both prompt and image
-        if prompt and image and 10 < len(prompt) <= max_length:
-            samples.append({
-                'prompt': prompt,
-                'image': image  # This is already a PIL Image object
-            })
+        if not pdf_files:
+            raise ValueError(f"No PDF files found in {pdf_pattern}")
 
-        # Stop if we've examined too many without finding enough
-        if examined > num_samples * 10:
-            break
-
-    # Randomly sample exactly num_samples
-    if len(samples) < num_samples:
-        print(f"Only found {len(samples)} valid samples out of {examined} examined")
-        if len(samples) == 0:
-            raise ValueError("No valid samples found in dataset")
-        samples = random.choices(samples, k=num_samples)
-    else:
-        samples = random.sample(samples, num_samples)
-
-    print(f"Selected {len(samples)} samples for comparison")
-    return samples
-
+        print(f"Found {len(pdf_files)} PDF files")
+
+        # Randomly sample num_samples PDFs
+        if len(pdf_files) > num_samples:
+            sampled_pdfs = random.sample(pdf_files, num_samples)
+        else:
+            sampled_pdfs = pdf_files
+            print(f"Warning: Only {len(pdf_files)} PDFs available, less than requested {num_samples}")
+
+        # Process each PDF and build queries
+        queries = []
+        for pdf_path in sampled_pdfs:
+            try:
+                # Build page query for page 1 of each PDF
+                query = asyncio.run(build_page_query(
+                    local_pdf_path=pdf_path,
+                    page=1,
+                    target_longest_image_dim=1280,
+                    image_rotation=0
+                ))
+                queries.append(query)
+            except Exception as e:
+                print(f"Error processing {os.path.basename(pdf_path)}: {e}")
+                continue
+
+        print(f"Successfully processed {len(queries)} PDFs")
+        return queries
 
 def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
     """Process a single prompt with image and return comparison results."""
-    prompt = sample['prompt']
-    image = sample['image']  # Already a PIL Image object
+    # Extract messages from the sample (which is the output of build_page_query)
+    messages = sample['messages']
+
+    # Extract the text prompt and image from the messages
+    user_message = messages[0]
+    text_prompt = None
+    image_base64 = None
+
+    for content in user_message['content']:
+        if content['type'] == 'text':
+            text_prompt = content['text']
+        elif content['type'] == 'image_url':
+            image_url = content['image_url']['url']
+            # Extract base64 data after the comma
+            if ',' in image_url:
+                image_base64 = image_url.split(',')[1]
+            else:
+                image_base64 = image_url
+
+    if text_prompt is None or image_base64 is None:
+        raise ValueError("Failed to extract text prompt or image from messages")
+
+    # Decode the base64 image to PIL Image
+    image_bytes = base64.b64decode(image_base64)
+    image = PIL.Image.open(BytesIO(image_bytes))
 
     print(f"\n{'='*80}")
-    print(f"PROMPT: {prompt[:100]}..." if len(prompt) > 100 else f"PROMPT: {prompt}")
+    print(f"PROMPT: {text_prompt[:100]}..." if len(text_prompt) > 100 else f"PROMPT: {text_prompt}")
     print(f"IMAGE: {image.size} {image.mode}")
 
     # Generate with vLLM
     print("\n=== vLLM Generation ===")
-    # Convert image to base64 data URL
-    image_data_url = image_to_base64_data_url(image)
 
-    # For VLMs, vLLM expects the message format with image
-    messages = [{
-        "role": "user",
-        "content": [
-            {"type": "image_url", "image_url": {"url": image_data_url}},
-            {"type": "text", "text": prompt}
-        ]
-    }]
+    # For VLLM, use the messages just as comes out of build_page_query
     outputs = llm.chat(messages, sampling_params)
     output = outputs[0]
 
@@ -110,19 +168,19 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
 
     # HuggingFace forward pass
     print("\n=== HuggingFace Forward Pass ===")
-    # Prepare inputs for HF model
+    # Prepare inputs for HF model using the extracted image and text
     conversation = [
         {
             "role": "user",
             "content": [
                 {"type": "image"},
-                {"type": "text", "text": prompt}
+                {"type": "text", "text": text_prompt}
             ]
         }
     ]
-    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    hf_text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     inputs = processor(
-        text=[text_prompt],
+        text=[hf_text_prompt],
         images=[image],
         return_tensors="pt"
     ).to(device)
@@ -233,7 +291,7 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
     }
 
 
-def main():
+async def async_main():
     parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
     parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
                         help="Model name or path")
@@ -253,14 +311,17 @@ def main():
     print(f"Max tokens: {args.max_tokens}")
     print(f"Temperature: {args.temperature}")
     print(f"Probability threshold: {args.prob_threshold}")
-    print(f"Loading {args.num_prompts} samples from WildVision-bench\n")
+    print(f"Loading {args.num_prompts} samples from olmOCR-mix-0225-benchmarkset\n")
+
+    # Download the model before loading prompts
+    model_path = await download_model(args.model)
 
     # Load prompts and images
-    samples = load_wildvision_prompts(num_samples=args.num_prompts, seed=args.seed)
+    samples = load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
 
     # Create vLLM engine
     print("\n=== Creating vLLM Engine ===")
-    llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)
+    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
     sampling_params = SamplingParams(
         temperature=args.temperature,
         max_tokens=args.max_tokens,
@@ -278,9 +339,9 @@ def main():
     # Load HuggingFace model and processor
     print("\n=== Loading HuggingFace Model ===")
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    processor_hf = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
+    processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
     hf_model = AutoModelForVision2Seq.from_pretrained(
-        args.model,
+        model_path,
         trust_remote_code=True,
         torch_dtype=torch.float16,
         device_map="auto"
@@ -295,7 +356,7 @@ def main():
         print(f"{'#'*80}")
 
         # Recreate vLLM for each prompt
-        llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)
+        llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
 
         # Process single sample
         result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
@@ -319,5 +380,9 @@ def main():
     print(f"{'='*80}")
 
 
+def main():
+    asyncio.run(async_main())
+
+
 if __name__ == "__main__":
     main()
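For context, here is a minimal usage sketch of the two helpers this commit introduces. It is not part of the commit: the import path assumes `olmocr/train/compare_vllm_checkpoint.py` is importable as `olmocr.train.compare_vllm_checkpoint` from the repo root, and the model id and sample count are illustrative only.

import asyncio

from olmocr.train.compare_vllm_checkpoint import download_model, load_pdf_prompts

# download_model copies an s3://, gs://, or weka:// directory into a local cache,
# passes through an existing local directory, or pre-fetches a Hugging Face repo id.
model_path = asyncio.run(download_model("Qwen/Qwen2.5-VL-7B-Instruct"))

# load_pdf_prompts samples benchmark PDFs and returns build_page_query outputs,
# each carrying a "messages" payload that llm.chat() can consume directly.
samples = load_pdf_prompts(num_samples=5, seed=42)
print(model_path, len(samples))

In the script itself, async_main awaits download_model, loads the prompts, and then recreates the vLLM engine per sample before calling process_single_prompt.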
