#!/usr/bin/env python3
"""
Batch VLM inference comparison between vLLM and HuggingFace.
Processes prompts and images from WildVision-bench until finding a significant mismatch.
"""

import argparse
import gc
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from datasets import load_dataset
import random
import numpy as np
from typing import Any, List, Dict
import base64
from io import BytesIO


def image_to_base64_data_url(image):
    """Convert PIL image to base64 data URL."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


def load_wildvision_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, Any]]:
    """Load prompts and images from the WildVision-bench dataset with a fixed random seed."""
    print(f"Loading WildVision-bench dataset with {num_samples} samples and seed {seed}")

    # Set random seed for reproducibility
    random.seed(seed)
    np.random.seed(seed)

    # Load dataset
    dataset = load_dataset("WildVision/wildvision-bench", "vision_bench_0701", split="test", streaming=True)

    # Collect prompts and images
    samples = []
    examined = 0
    for example in dataset:
        examined += 1
        if len(samples) >= num_samples * 2:  # Collect extra to allow filtering
            break

        # Extract prompt and image from the example
        prompt = example.get('instruction', '').strip()
        image = example.get('image', None)

        # Filter by prompt length and ensure we have both prompt and image
        if prompt and image and 10 < len(prompt) <= max_length:
            samples.append({
                'prompt': prompt,
                'image': image  # This is already a PIL Image object
            })

        # Stop if we've examined too many without finding enough
        if examined > num_samples * 10:
            break

    # Randomly sample exactly num_samples
    if len(samples) < num_samples:
        print(f"Only found {len(samples)} valid samples out of {examined} examined")
        if len(samples) == 0:
            raise ValueError("No valid samples found in dataset")
        samples = random.choices(samples, k=num_samples)
    else:
        samples = random.sample(samples, num_samples)

    print(f"Selected {len(samples)} samples for comparison")
    return samples


def process_single_prompt(sample: Dict[str, Any], llm, hf_model, processor, sampling_params, device, args):
    """Process a single prompt with image and return comparison results."""
    prompt = sample['prompt']
    image = sample['image']  # Already a PIL Image object

    print(f"\n{'=' * 80}")
    print(f"PROMPT: {prompt[:100]}..." if len(prompt) > 100 else f"PROMPT: {prompt}")
    print(f"IMAGE: {image.size} {image.mode}")

    # Generate with vLLM
    print("\n=== vLLM Generation ===")
    # Convert image to base64 data URL
    image_data_url = image_to_base64_data_url(image)

    # For VLMs, vLLM expects the chat message format with an image content part
    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_data_url}},
            {"type": "text", "text": prompt}
        ]
    }]
    outputs = llm.chat(messages, sampling_params)
    output = outputs[0]

    # Extract prompt and generated token IDs
    prompt_token_ids = output.prompt_token_ids
    generated_token_ids = output.outputs[0].token_ids

    print(f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..." if len(prompt_token_ids) > 10 else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}")
    print(f"Generated tokens ({len(generated_token_ids)}): {generated_token_ids}")
    print(f"Generated text: {processor.decode(generated_token_ids, skip_special_tokens=True)}")

    # Create input tensor from concatenated token IDs
    # input_ids = torch.tensor([all_token_ids], device=device)  # Not needed for HF VLM models

    # HuggingFace forward pass
    print("\n=== HuggingFace Forward Pass ===")
    # Prepare inputs for HF model
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(
        text=[text_prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)

    print("INPUTS", inputs)

    # Concatenate the generated tokens to the input_ids
    generated_ids_tensor = torch.tensor([generated_token_ids], device=device)
    inputs["input_ids"] = torch.cat([inputs["input_ids"], generated_ids_tensor], dim=1)
    inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])

    with torch.no_grad():
        outputs_hf = hf_model(**inputs)
        logits = outputs_hf.logits[0]  # [seq_len, vocab_size]
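
    # Note on alignment: for a causal LM, logits[t] is the distribution over the
    # token at position t + 1, which is why the comparison below reads
    # logits[pos - 1] for the token observed at position pos.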

    # Token-by-token comparison
    print(f"\n{'Pos':>4} {'Token ID':>8} {'Token':>20} {'Type':>8} {'vLLM Prob':>12} {'HF Argmax':>10} {'HF Prob':>12} {'Match':>6} {'HF Token':>20}")
    print("-" * 125)

    # Get vLLM logprobs for generated tokens
    vllm_logprobs = output.outputs[0].logprobs

    # Track mismatch info
    first_mismatch_idx = None
    max_prob_diff = 0.0

    # Get all token IDs from the HF model's input
    all_token_ids = inputs["input_ids"][0].tolist()
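
    # Assumption: the HF-tokenized prompt (including expanded image placeholder
    # tokens) lines up with vLLM's prompt_token_ids; if the two front-ends
    # tokenize the image differently, the prompt/gen split below is misaligned.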

    # Compare ALL tokens (prompt + generated)
    for pos, token_id in enumerate(all_token_ids):
        token_str = processor.decode([token_id], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')

        # Determine if this is a prompt or generated token
        is_prompt = pos < len(prompt_token_ids)
        token_type = "prompt" if is_prompt else "gen"

        # vLLM probability (only for generated tokens)
        vllm_prob_str = "N/A"
        vllm_prob = None
        if not is_prompt:
            gen_idx = pos - len(prompt_token_ids)
            if vllm_logprobs and gen_idx < len(vllm_logprobs):
                # vLLM logprobs is a list of dicts mapping token_id to logprob
                token_logprobs = vllm_logprobs[gen_idx]
                if token_logprobs and token_id in token_logprobs:
                    # Convert logprob to probability
                    vllm_prob = torch.exp(torch.tensor(token_logprobs[token_id].logprob)).item()
                    vllm_prob_str = f"{vllm_prob:12.6f}"

        # HF prediction - only for generated tokens (skip prompt tokens entirely)
        if pos > 0 and not is_prompt:
            hf_logits_at_pos = logits[pos - 1]
            hf_probs = torch.softmax(hf_logits_at_pos, dim=-1)
            hf_argmax = torch.argmax(hf_logits_at_pos).item()
            hf_prob = hf_probs[token_id].item()

            # Check if predictions match
            match = "✓" if token_id == hf_argmax else "✗"

            # Track first mismatch and probability difference
            if token_id != hf_argmax:
                if first_mismatch_idx is None:
                    first_mismatch_idx = pos - len(prompt_token_ids)

            # Calculate probability difference
            if vllm_prob is not None:
                prob_diff = abs(vllm_prob - hf_prob)
                max_prob_diff = max(max_prob_diff, prob_diff)

            # Decode HF argmax token (only show if mismatch)
            hf_token_str = ""
            if token_id != hf_argmax:
                hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')

            print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {hf_argmax:>10} {hf_prob:>12.6f} {match:>6} {hf_token_str:>20}")
        else:
            # Prompt tokens or first token - no HF comparison
            print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {'':>10} {'':>12} {'':>6} {'':<20}")

    # Summary
    print("\n=== Summary ===")
    print(f"Total tokens generated: {len(generated_token_ids)}")

    # Calculate match rate
    matches = 0
    for i, token_id in enumerate(generated_token_ids):
        pos = len(prompt_token_ids) + i
        hf_logits_at_pos = logits[pos - 1]
        hf_argmax = torch.argmax(hf_logits_at_pos).item()
        if token_id == hf_argmax:
            matches += 1

    match_rate = matches / len(generated_token_ids) * 100 if generated_token_ids else 0
    print(f"Token match rate: {matches}/{len(generated_token_ids)} ({match_rate:.1f}%)")

    # Report first mismatch index
    if first_mismatch_idx is not None:
        print(f"First mismatch at generation index: {first_mismatch_idx}")
        print(f"Max probability difference: {max_prob_diff:.6f}")
    else:
        print("No mismatches found in generated tokens")

    return {
        'first_mismatch_idx': first_mismatch_idx,
        'max_prob_diff': max_prob_diff,
        'match_rate': match_rate
    }


def main():
    parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
                        help="Model name or path")
    parser.add_argument("--max-tokens", type=int, default=20,
                        help="Maximum tokens to generate per prompt")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="Sampling temperature")
    parser.add_argument("--num-prompts", type=int, default=100,
                        help="Number of prompts to load from WildVision")
    parser.add_argument("--prob-threshold", type=float, default=0.20,
                        help="Probability difference threshold to stop")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for prompt selection")
    args = parser.parse_args()

    print(f"Model: {args.model}")
    print(f"Max tokens: {args.max_tokens}")
    print(f"Temperature: {args.temperature}")
    print(f"Probability threshold: {args.prob_threshold}")
    print(f"Loading {args.num_prompts} samples from WildVision-bench\n")

    # Load prompts and images
    samples = load_wildvision_prompts(num_samples=args.num_prompts, seed=args.seed)

    # Create vLLM engine
    print("\n=== Creating vLLM Engine ===")
    llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        logprobs=1  # Get top-1 logprobs
    )
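
    # With the default temperature of 0.0, vLLM decodes greedily, so comparing its
    # chosen tokens against the HF argmax is meaningful; logprobs=1 requests
    # per-token log-probabilities, which process_single_prompt converts back to
    # probabilities for the diff.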

    # Get processor (VLMs use processor instead of tokenizer)
    # processor = llm.get_tokenizer()  # Not needed, we get it later

    # Clean up vLLM before loading HF model
    del llm
    gc.collect()
    torch.cuda.empty_cache()

    # Load HuggingFace model and processor
    print("\n=== Loading HuggingFace Model ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor_hf = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    hf_model = AutoModelForVision2Seq.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    hf_model.eval()
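
    # Note: the HF model runs in float16 here while vLLM uses its own default dtype
    # and kernels, so small numeric differences (and occasional argmax flips on
    # near-ties) are expected even when both backends are behaving correctly.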

    # Process samples until finding a significant mismatch
    print("\n=== Processing Samples ===")
    for i, sample in enumerate(samples):
        print(f"\n\n{'#' * 80}")
        print(f"### Processing sample {i + 1}/{len(samples)}")
        print(f"{'#' * 80}")

        # Recreate vLLM for each prompt
        llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)

        # Process single sample
        result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)

        # Clean up vLLM after each prompt
        del llm
        gc.collect()
        torch.cuda.empty_cache()

        # Check if we found a significant mismatch
        if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
            print(f"\n\n{'*' * 80}")
            print("*** FOUND SIGNIFICANT MISMATCH ***")
            print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
            print(f"*** Stopping after sample {i + 1}/{len(samples)} ***")
            print(f"{'*' * 80}")
            break
    else:
        # for-else: runs only if the loop completed without hitting `break`
        print(f"\n\n{'=' * 80}")
        print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
        print(f"{'=' * 80}")


if __name__ == "__main__":
    main()