@@ -128,6 +128,9 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i

def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
    """Process a single prompt with image and return comparison results."""
+    # Track whether we have seen the first mismatch for max_prob_first_diff
+    found_first_mismatch = False
+    max_prob_first_diff = 0.0
    # Extract messages from the sample (which is the output of build_page_query)
    messages = sample['messages']
@@ -215,7 +218,6 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp

    # Track mismatch info
    first_mismatch_idx = None
-    max_prob_diff = 0.0

    # Get all token IDs from the HF model's input
    all_token_ids = inputs["input_ids"][0].tolist()
@@ -255,11 +257,10 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
        if token_id != hf_argmax:
            if first_mismatch_idx is None:
                first_mismatch_idx = pos - len(prompt_token_ids)
-
-            # Calculate probability difference
-            if vllm_prob is not None:
-                prob_diff = abs(vllm_prob - hf_prob)
-                max_prob_diff = max(max_prob_diff, prob_diff)
+            # Calculate probability difference only for the first mismatch
+            if vllm_prob is not None and not found_first_mismatch:
+                max_prob_first_diff = abs(vllm_prob - hf_prob)
+                found_first_mismatch = True

        # Decode HF argmax token (only show if mismatch)
        hf_token_str = ""
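For reference, the first-mismatch bookkeeping introduced above can be exercised in isolation. A minimal sketch with toy per-position tuples (hypothetical values, not real model outputs) standing in for the vLLM/HF comparison loop:

```python
# Sketch of the first-mismatch tracking pattern above (toy values, no models).
prompt_token_ids = [101, 102]        # pretend prompt tokens
records = [                          # (token_id, hf_argmax, vllm_prob, hf_prob)
    (11, 11, 0.98, 0.97),            # position matches
    (42, 17, 0.55, 0.61),            # first mismatch: argmaxes disagree
    (99, 23, 0.40, 0.70),            # later mismatch: ignored for the diff
]

first_mismatch_idx = None
found_first_mismatch = False
max_prob_first_diff = 0.0

for pos, (token_id, hf_argmax, vllm_prob, hf_prob) in enumerate(records, start=len(prompt_token_ids)):
    if token_id != hf_argmax:
        if first_mismatch_idx is None:
            first_mismatch_idx = pos - len(prompt_token_ids)
        # Only the first disagreement contributes to max_prob_first_diff.
        if vllm_prob is not None and not found_first_mismatch:
            max_prob_first_diff = abs(vllm_prob - hf_prob)
            found_first_mismatch = True

print(first_mismatch_idx, round(max_prob_first_diff, 6))   # -> 1 0.06
```

Note the change in metric: the old code tracked the maximum probability gap across all mismatched positions, while the new code scores only the first point of divergence.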
@@ -290,14 +291,15 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
    # Report first mismatch index
    if first_mismatch_idx is not None:
        print(f"First mismatch at generation index: {first_mismatch_idx}")
-        print(f"Max probability difference: {max_prob_diff:.6f}")
+        print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
    else:
        print("No mismatches found in generated tokens")

    return {
        'first_mismatch_idx': first_mismatch_idx,
-        'max_prob_diff': max_prob_diff,
-        'match_rate': match_rate
+        'max_prob_first_diff': max_prob_first_diff,
+        'match_rate': match_rate,
+        'num_generated': len(generated_token_ids)
    }
@@ -329,24 +331,7 @@ async def async_main():
    # Load prompts and images
    samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)

-    # Create vLLM engine
-    print("\n=== Creating vLLM Engine ===")
-    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
-    sampling_params = SamplingParams(
-        temperature=args.temperature,
-        max_tokens=args.max_tokens,
-        logprobs=1  # Get top-1 logprobs
-    )
-
-    # Get processor (VLMs use processor instead of tokenizer)
-    # processor = llm.get_tokenizer()  # Not needed, we get it later
-
-    # Clean up vLLM before loading HF model
-    del llm
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # Load HuggingFace model and processor
+    # Load HuggingFace model and processor first
    print("\n=== Loading HuggingFace Model ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
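With this reordering, the HF model stays resident on the GPU while the vLLM engine is created in the next hunk, so both must fit at once; `gpu_memory_utilization=0.5` caps vLLM's share. A minimal sketch for sanity-checking headroom before constructing the engine (assumes a CUDA device; `report_gpu_headroom` is a hypothetical helper, not part of this PR):

```python
import torch

def report_gpu_headroom(device: int = 0) -> float:
    """Print free/total memory for a CUDA device and return the free fraction."""
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    frac = free_bytes / total_bytes
    print(f"GPU {device}: {free_bytes / 1e9:.1f} GB free of {total_bytes / 1e9:.1f} GB ({frac:.0%})")
    return frac

# Example: after hf_model is loaded, before LLM(...) is constructed.
# if report_gpu_headroom() < 0.5:
#     raise RuntimeError("less than half the GPU is free; vLLM at 0.5 utilization may OOM")
```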
@@ -358,36 +343,73 @@ async def async_main():
    )
    hf_model.eval()

+    # Create vLLM engine once
+    print("\n=== Creating vLLM Engine ===")
+    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
+    sampling_params = SamplingParams(
+        temperature=args.temperature,
+        max_tokens=args.max_tokens,
+        logprobs=1  # Get top-1 logprobs
+    )
+
    # Process samples until finding significant mismatch
    print("\n=== Processing Samples ===")
+
+    # Initialize statistics tracking
+    all_results = []
    for i, sample in enumerate(samples):
        print(f"\n\n{'#' * 80}")
        print(f"### Processing sample {i + 1}/{len(samples)}")
        print(f"{'#' * 80}")

-        # Recreate vLLM for each prompt
-        llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
-
        # Process single sample
        result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
-
-        # Clean up vLLM after each prompt
-        del llm
-        gc.collect()
-        torch.cuda.empty_cache()
+        all_results.append(result)

        # Check if we found significant mismatch
-        if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
+        if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
            print(f"\n\n{'*' * 80}")
            print(f"*** FOUND SIGNIFICANT MISMATCH ***")
-            print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
+            print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
            print(f"*** Stopping after sample {i + 1}/{len(samples)} ***")
            print(f"{'*' * 80}")
            break
    else:
        print(f"\n\n{'=' * 80}")
        print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
        print(f"{'=' * 80}")
+
+    # Report aggregated statistics
+    print(f"\n\n{'=' * 80}")
+    print("=== AGGREGATED STATISTICS ===")
+    print(f"{'=' * 80}")
+
+    total_samples = len(all_results)
+    samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
+    total_tokens_generated = sum(r['num_generated'] for r in all_results)
+
+    print(f"Total samples processed: {total_samples}")
+    print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches / total_samples * 100:.1f}%)")
+    print(f"Total tokens generated: {total_tokens_generated}")
+
+    if samples_with_mismatches > 0:
+        avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
+        max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
+        avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
+        max_prob_diff_overall = max(max_prob_diffs)
+
+        first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
+        avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
+
+        print(f"\nMismatch Statistics:")
+        print(f"  Average token match rate: {avg_match_rate:.1f}%")
+        print(f"  Average first mismatch position: {avg_first_mismatch_pos:.1f}")
+        print(f"  Average first mismatch prob diff: {avg_prob_diff:.6f}")
+        print(f"  Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
+    else:
+        print("\nNo mismatches found in any samples!")
+
+    print(f"\n{'=' * 80}")


def main():
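A usage note on the new aggregation block: it depends only on the per-sample result dicts returned by `process_single_prompt`, so it can be checked without any GPU or model. A minimal sketch with hypothetical results:

```python
# Sketch: the aggregation logic above, run on hypothetical per-sample results.
all_results = [
    {'first_mismatch_idx': None, 'max_prob_first_diff': 0.0, 'match_rate': 100.0, 'num_generated': 64},
    {'first_mismatch_idx': 12, 'max_prob_first_diff': 0.08, 'match_rate': 93.8, 'num_generated': 64},
]

total_samples = len(all_results)
mismatched = [r for r in all_results if r['first_mismatch_idx'] is not None]
print(f"Samples with mismatches: {len(mismatched)} ({len(mismatched) / total_samples * 100:.1f}%)")
print(f"Total tokens generated: {sum(r['num_generated'] for r in all_results)}")
if mismatched:
    diffs = [r['max_prob_first_diff'] for r in mismatched]
    print(f"Average first mismatch prob diff: {sum(diffs) / len(diffs):.6f}")
    print(f"Max first mismatch prob diff: {max(diffs):.6f}")
```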