
Commit 2460895

Working on comparing to vllm
1 parent e6c9823 commit 2460895

2 files changed (+324, -1 lines changed)

Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
#!/usr/bin/env python3
"""
Batch VLM inference comparison between vLLM and HuggingFace.
Processes prompts and images from WildVision-bench until finding a significant mismatch.
"""

import argparse
import gc
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from datasets import load_dataset
import random
import numpy as np
from typing import Any, Dict, List
import base64
from io import BytesIO


def image_to_base64_data_url(image):
    """Convert PIL image to base64 data URL."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


def load_wildvision_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, Any]]:
    """Load prompts and images from WildVision-bench dataset with fixed random seed."""
    print(f"Loading WildVision-bench dataset with {num_samples} samples and seed {seed}")

    # Set random seed for reproducibility
    random.seed(seed)
    np.random.seed(seed)

    # Load dataset
    dataset = load_dataset("WildVision/wildvision-bench", "vision_bench_0701", split="test", streaming=True)

    # Collect prompts and images
    samples = []
    examined = 0
    for example in dataset:
        examined += 1
        if len(samples) >= num_samples * 2:  # Collect extra to allow filtering
            break

        # Extract prompt and image from the example
        prompt = example.get('instruction', '').strip()
        image = example.get('image', None)

        # Filter by prompt length and ensure we have both prompt and image
        if prompt and image and 10 < len(prompt) <= max_length:
            samples.append({
                'prompt': prompt,
                'image': image  # This is already a PIL Image object
            })

        # Stop if we've examined too many without finding enough
        if examined > num_samples * 10:
            break

    # Randomly sample exactly num_samples
    if len(samples) < num_samples:
        print(f"Only found {len(samples)} valid samples out of {examined} examined")
        if len(samples) == 0:
            raise ValueError("No valid samples found in dataset")
        samples = random.choices(samples, k=num_samples)
    else:
        samples = random.sample(samples, num_samples)

    print(f"Selected {len(samples)} samples for comparison")
    return samples


def process_single_prompt(sample: Dict[str, Any], llm, hf_model, processor, sampling_params, device, args):
    """Process a single prompt with image and return comparison results."""
    prompt = sample['prompt']
    image = sample['image']  # Already a PIL Image object

    print(f"\n{'='*80}")
    print(f"PROMPT: {prompt[:100]}..." if len(prompt) > 100 else f"PROMPT: {prompt}")
    print(f"IMAGE: {image.size} {image.mode}")

    # Generate with vLLM
    print("\n=== vLLM Generation ===")
    # Convert image to base64 data URL
    image_data_url = image_to_base64_data_url(image)

    # For VLMs, vLLM expects the message format with image
    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_data_url}},
            {"type": "text", "text": prompt}
        ]
    }]
    outputs = llm.chat(messages, sampling_params)
    output = outputs[0]

    # Extract prompt and generated token IDs
    prompt_token_ids = output.prompt_token_ids
    generated_token_ids = output.outputs[0].token_ids

print(f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..." if len(prompt_token_ids) > 10 else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}")
105+
print(f"Generated tokens ({len(generated_token_ids)}): {generated_token_ids}")
106+
print(f"Generated text: {processor.decode(generated_token_ids, skip_special_tokens=True)}")
107+
108+
# Create input tensor from concatenated token IDs
109+
# input_ids = torch.tensor([all_token_ids], device=device) # Not needed for HF VLM models
110+
111+
# HuggingFace forward pass
112+
print("\n=== HuggingFace Forward Pass ===")
113+
# Prepare inputs for HF model
114+
conversation = [
115+
{
116+
"role": "user",
117+
"content": [
118+
{"type": "image"},
119+
{"type": "text", "text": prompt}
120+
]
121+
}
122+
]
123+
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
124+
inputs = processor(
125+
text=[text_prompt],
126+
images=[image],
127+
return_tensors="pt"
128+
).to(device)
129+
130+
print("INPUTS", inputs)
131+
132+
# Concatenate the generated tokens to the input_ids
133+
generated_ids_tensor = torch.tensor([generated_token_ids], device=device)
134+
inputs["input_ids"] = torch.cat([inputs["input_ids"], generated_ids_tensor], dim=1)
135+
inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
136+
137+
with torch.no_grad():
138+
outputs_hf = hf_model(**inputs)
139+
logits = outputs_hf.logits[0] # [seq_len, vocab_size]
140+
141+
    # Token-by-token comparison
    print(f"\n{'Pos':>4} {'Token ID':>8} {'Token':>20} {'Type':>8} {'vLLM Prob':>12} {'HF Argmax':>10} {'HF Prob':>12} {'Match':>6} {'HF Token':>20}")
    print("-" * 125)

    # Get vLLM logprobs for generated tokens
    vllm_logprobs = output.outputs[0].logprobs

    # Track mismatch info
    first_mismatch_idx = None
    max_prob_diff = 0.0

    # Get all token IDs from the HF model's input
    all_token_ids = inputs["input_ids"][0].tolist()

    # Compare ALL tokens (prompt + generated)
    for pos, token_id in enumerate(all_token_ids):
        token_str = processor.decode([token_id], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')

        # Determine if this is a prompt or generated token
        is_prompt = pos < len(prompt_token_ids)
        token_type = "prompt" if is_prompt else "gen"

        # vLLM probability (only for generated tokens)
        vllm_prob_str = "N/A"
        vllm_prob = None
        if not is_prompt:
            gen_idx = pos - len(prompt_token_ids)
            if vllm_logprobs and gen_idx < len(vllm_logprobs):
                # vLLM logprobs is a list of dicts mapping token_id to logprob
                token_logprobs = vllm_logprobs[gen_idx]
                if token_logprobs and token_id in token_logprobs:
                    # Convert logprob to probability
                    vllm_prob = torch.exp(torch.tensor(token_logprobs[token_id].logprob)).item()
                    vllm_prob_str = f"{vllm_prob:12.6f}"

        # HF prediction - only for generated tokens (skip prompt tokens entirely)
        if pos > 0 and not is_prompt:
            hf_logits_at_pos = logits[pos - 1]
            hf_probs = torch.softmax(hf_logits_at_pos, dim=-1)
            hf_argmax = torch.argmax(hf_logits_at_pos).item()
            hf_prob = hf_probs[token_id].item()

            # Check if predictions match
            match = "✓" if token_id == hf_argmax else "✗"

            # Track first mismatch and probability difference
            if token_id != hf_argmax:
                if first_mismatch_idx is None:
                    first_mismatch_idx = pos - len(prompt_token_ids)

            # Calculate probability difference
            if vllm_prob is not None:
                prob_diff = abs(vllm_prob - hf_prob)
                max_prob_diff = max(max_prob_diff, prob_diff)

            # Decode HF argmax token (only show if mismatch)
            hf_token_str = ""
            if token_id != hf_argmax:
                hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')

            print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {hf_argmax:>10} {hf_prob:>12.6f} {match:>6} {hf_token_str:>20}")
        else:
            # Prompt tokens or first token - no HF comparison
            print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {'':>10} {'':>12} {'':>6} {'':<20}")

    # Summary
    print(f"\n=== Summary ===")
    print(f"Total tokens generated: {len(generated_token_ids)}")

    # Calculate match rate
    matches = 0
    for i, token_id in enumerate(generated_token_ids):
        pos = len(prompt_token_ids) + i
        hf_logits_at_pos = logits[pos - 1]
        hf_argmax = torch.argmax(hf_logits_at_pos).item()
        if token_id == hf_argmax:
            matches += 1

    match_rate = matches / len(generated_token_ids) * 100 if generated_token_ids else 0
    print(f"Token match rate: {matches}/{len(generated_token_ids)} ({match_rate:.1f}%)")

    # Report first mismatch index
    if first_mismatch_idx is not None:
        print(f"First mismatch at generation index: {first_mismatch_idx}")
        print(f"Max probability difference: {max_prob_diff:.6f}")
    else:
        print("No mismatches found in generated tokens")

    return {
        'first_mismatch_idx': first_mismatch_idx,
        'max_prob_diff': max_prob_diff,
        'match_rate': match_rate
    }


def main():
    parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
                        help="Model name or path")
    parser.add_argument("--max-tokens", type=int, default=20,
                        help="Maximum tokens to generate per prompt")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="Sampling temperature")
    parser.add_argument("--num-prompts", type=int, default=100,
                        help="Number of prompts to load from WildVision")
    parser.add_argument("--prob-threshold", type=float, default=0.20,
                        help="Probability difference threshold to stop")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for prompt selection")
    args = parser.parse_args()

    print(f"Model: {args.model}")
    print(f"Max tokens: {args.max_tokens}")
    print(f"Temperature: {args.temperature}")
    print(f"Probability threshold: {args.prob_threshold}")
    print(f"Loading {args.num_prompts} samples from WildVision-bench\n")

    # Load prompts and images
    samples = load_wildvision_prompts(num_samples=args.num_prompts, seed=args.seed)

    # Create vLLM engine
    print("\n=== Creating vLLM Engine ===")
    llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        logprobs=1  # Get top-1 logprobs
    )

    # Get processor (VLMs use processor instead of tokenizer)
    # processor = llm.get_tokenizer()  # Not needed, we get it later

    # Clean up vLLM before loading HF model
    del llm
    gc.collect()
    torch.cuda.empty_cache()

    # Load HuggingFace model and processor
    print("\n=== Loading HuggingFace Model ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor_hf = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    hf_model = AutoModelForVision2Seq.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    hf_model.eval()

    # Process samples until finding significant mismatch
    print("\n=== Processing Samples ===")
    for i, sample in enumerate(samples):
        print(f"\n\n{'#'*80}")
        print(f"### Processing sample {i+1}/{len(samples)}")
        print(f"{'#'*80}")

        # Recreate vLLM for each prompt
        llm = LLM(model=args.model, trust_remote_code=True, gpu_memory_utilization=0.5)

        # Process single sample
        result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)

        # Clean up vLLM after each prompt
        del llm
        gc.collect()
        torch.cuda.empty_cache()

        # Check if we found significant mismatch
        if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
            print(f"\n\n{'*'*80}")
            print(f"*** FOUND SIGNIFICANT MISMATCH ***")
            print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
            print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
            print(f"{'*'*80}")
            break
    else:
        print(f"\n\n{'='*80}")
        print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
        print(f"{'='*80}")


if __name__ == "__main__":
    main()
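For reference, a typical invocation based on the argparse defaults above. The new file's path is not shown in this diff view, so the script name below is a placeholder, not the committed filename:

# Placeholder filename; the diff header above does not show the new file's path.
python compare_vlm_outputs.py \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --max-tokens 20 \
    --num-prompts 100 \
    --prob-threshold 0.20 \
    --seed 42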

scripts/compress_model.sh

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ fi
 # Create Python script to run beaker experiment
 cat << 'EOF' > /tmp/run_compress_experiment.py
 import sys
-from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar
+from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar, DataMount
 
 # Get parameters from command line
 image_tag = sys.argv[1]
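Presumably DataMount is being imported so the generated experiment spec can mount a dataset or host path into the Beaker task. A minimal sketch of how that is usually wired up in beaker-py is shown below; the dataset reference and mount path are invented for illustration, and DataSource would also need to be imported alongside DataMount (or DataMount.new used instead). This is an assumption about intent, not part of the commit:

# Sketch (assumption): attaching a DataMount to a beaker-py TaskSpec.
from beaker import DataMount, DataSource

mount = DataMount(
    source=DataSource(beaker="examples/my-dataset"),  # hypothetical dataset reference
    mount_path="/data",                               # hypothetical mount point inside the task
)
# The mount would then be passed to the task, e.g. TaskSpec(..., datasets=[mount]).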
