Skip to content

Commit 0f733ff

Browse files
committed
Fixes to the compare vLLM script
1 parent 16145a4 commit 0f733ff

File tree

1 file changed

+59
-37
lines changed

1 file changed

+59
-37
lines changed

olmocr/train/compare_vllm_checkpoint.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
128128

129129
def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
130130
"""Process a single prompt with image and return comparison results."""
131+
# Track if we found the first mismatch for max_prob_first_diff
132+
found_first_mismatch = False
133+
max_prob_first_diff = 0.0
131134
# Extract messages from the sample (which is the output of build_page_query)
132135
messages = sample['messages']
133136

@@ -215,7 +218,6 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
215218

216219
# Track mismatch info
217220
first_mismatch_idx = None
218-
max_prob_diff = 0.0
219221

220222
# Get all token IDs from the HF model's input
221223
all_token_ids = inputs["input_ids"][0].tolist()
@@ -255,11 +257,10 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
255257
if token_id != hf_argmax:
256258
if first_mismatch_idx is None:
257259
first_mismatch_idx = pos - len(prompt_token_ids)
258-
259-
# Calculate probability difference
260-
if vllm_prob is not None:
261-
prob_diff = abs(vllm_prob - hf_prob)
262-
max_prob_diff = max(max_prob_diff, prob_diff)
260+
# Calculate probability difference only for the first mismatch
261+
if vllm_prob is not None and not found_first_mismatch:
262+
max_prob_first_diff = abs(vllm_prob - hf_prob)
263+
found_first_mismatch = True
263264

264265
# Decode HF argmax token (only show if mismatch)
265266
hf_token_str = ""
@@ -290,14 +291,15 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
290291
# Report first mismatch index
291292
if first_mismatch_idx is not None:
292293
print(f"First mismatch at generation index: {first_mismatch_idx}")
293-
print(f"Max probability difference: {max_prob_diff:.6f}")
294+
print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
294295
else:
295296
print("No mismatches found in generated tokens")
296297

297298
return {
298299
'first_mismatch_idx': first_mismatch_idx,
299-
'max_prob_diff': max_prob_diff,
300-
'match_rate': match_rate
300+
'max_prob_first_diff': max_prob_first_diff,
301+
'match_rate': match_rate,
302+
'num_generated': len(generated_token_ids)
301303
}
302304

303305

@@ -329,24 +331,7 @@ async def async_main():
329331
# Load prompts and images
330332
samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
331333

332-
# Create vLLM engine
333-
print("\n=== Creating vLLM Engine ===")
334-
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
335-
sampling_params = SamplingParams(
336-
temperature=args.temperature,
337-
max_tokens=args.max_tokens,
338-
logprobs=1 # Get top-1 logprobs
339-
)
340-
341-
# Get processor (VLMs use processor instead of tokenizer)
342-
# processor = llm.get_tokenizer() # Not needed, we get it later
343-
344-
# Clean up vLLM before loading HF model
345-
del llm
346-
gc.collect()
347-
torch.cuda.empty_cache()
348-
349-
# Load HuggingFace model and processor
334+
# Load HuggingFace model and processor first
350335
print("\n=== Loading HuggingFace Model ===")
351336
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
352337
processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
@@ -358,36 +343,73 @@ async def async_main():
358343
)
359344
hf_model.eval()
360345

346+
# Create vLLM engine once
347+
print("\n=== Creating vLLM Engine ===")
348+
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
349+
sampling_params = SamplingParams(
350+
temperature=args.temperature,
351+
max_tokens=args.max_tokens,
352+
logprobs=1 # Get top-1 logprobs
353+
)
354+
361355
# Process samples until finding significant mismatch
362356
print("\n=== Processing Samples ===")
357+
358+
# Initialize statistics tracking
359+
all_results = []
363360
for i, sample in enumerate(samples):
364361
print(f"\n\n{'#'*80}")
365362
print(f"### Processing sample {i+1}/{len(samples)}")
366363
print(f"{'#'*80}")
367364

368-
# Recreate vLLM for each prompt
369-
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
370-
371365
# Process single sample
372366
result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
373-
374-
# Clean up vLLM after each prompt
375-
del llm
376-
gc.collect()
377-
torch.cuda.empty_cache()
367+
all_results.append(result)
378368

379369
# Check if we found significant mismatch
380-
if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
370+
if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
381371
print(f"\n\n{'*'*80}")
382372
print(f"*** FOUND SIGNIFICANT MISMATCH ***")
383-
print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
373+
print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
384374
print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
385375
print(f"{'*'*80}")
386376
break
387377
else:
388378
print(f"\n\n{'='*80}")
389379
print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
390380
print(f"{'='*80}")
381+
382+
# Report aggregated statistics
383+
print(f"\n\n{'='*80}")
384+
print("=== AGGREGATED STATISTICS ===")
385+
print(f"{'='*80}")
386+
387+
total_samples = len(all_results)
388+
samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
389+
total_tokens_generated = sum(r['num_generated'] for r in all_results)
390+
391+
print(f"Total samples processed: {total_samples}")
392+
print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
393+
print(f"Total tokens generated: {total_tokens_generated}")
394+
395+
if samples_with_mismatches > 0:
396+
avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
397+
max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
398+
avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
399+
max_prob_diff_overall = max(max_prob_diffs)
400+
401+
first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
402+
avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
403+
404+
print(f"\nMismatch Statistics:")
405+
print(f" Average token match rate: {avg_match_rate:.1f}%")
406+
print(f" Average first mismatch position: {avg_first_mismatch_pos:.1f}")
407+
print(f" Average first mismatch prob diff: {avg_prob_diff:.6f}")
408+
print(f" Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
409+
else:
410+
print("\nNo mismatches found in any samples!")
411+
412+
print(f"\n{'='*80}")
391413

392414

393415
def main():

0 commit comments

Comments
 (0)