Skip to content

Commit 78fab24

Browse files
committed
add phi3 vision example
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 03e2177 commit 78fab24

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from datasets import load_dataset
2+
from transformers import AutoModelForCausalLM, AutoProcessor
3+
4+
from llmcompressor.modifiers.quantization import GPTQModifier
5+
from llmcompressor.transformers import oneshot
6+
from llmcompressor.transformers.utils.data_collator import phi3_vision_data_collator
7+
8+
# Load model.
9+
model_id = "microsoft/Phi-3-vision-128k-instruct"
10+
model = AutoModelForCausalLM.from_pretrained(
11+
model_id,
12+
device_map="auto",
13+
torch_dtype="auto",
14+
trust_remote_code=True,
15+
_attn_implementation="eager",
16+
)
17+
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
18+
processor.chat_template = processor.tokenizer.chat_template
19+
20+
# Oneshot arguments
21+
DATASET_ID = "lmms-lab/flickr30k"
22+
DATASET_SPLIT = "test[:512]"
23+
NUM_CALIBRATION_SAMPLES = 512
24+
MAX_SEQUENCE_LENGTH = 2048
25+
26+
# Load dataset and preprocess.
27+
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
28+
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
29+
30+
31+
# Apply chat template
32+
def preprocess(example):
33+
messages = [{"role": "user", "content": "<|image_1|>\nWhat does the image show?"}]
34+
return {
35+
"text": processor.apply_chat_template(
36+
messages,
37+
add_generation_prompt=True,
38+
),
39+
"images": example["image"],
40+
}
41+
42+
43+
ds = ds.map(preprocess)
44+
45+
46+
# # Tokenize inputs.
47+
def tokenize(sample):
48+
return processor(
49+
text=sample["text"],
50+
images=sample["images"],
51+
padding=False,
52+
max_length=MAX_SEQUENCE_LENGTH,
53+
truncation=True,
54+
)
55+
56+
57+
# long data lengths produced by the phi3_vision processor
58+
# can lead to integer overflows when mapping, avoid with writer_batch_size
59+
ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)
60+
61+
62+
# Recipe
63+
recipe = [
64+
GPTQModifier(
65+
targets="Linear",
66+
scheme="W4A16",
67+
sequential_targets=["Phi3DecoderLayer"],
68+
ignore=["lm_head", "re:model.vision_embed_tokens.*"],
69+
),
70+
]
71+
72+
# Perform oneshot
73+
oneshot(
74+
model=model,
75+
dataset=ds,
76+
recipe=recipe,
77+
max_seq_length=MAX_SEQUENCE_LENGTH,
78+
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
79+
trust_remote_code_model=True,
80+
data_collator=phi3_vision_data_collator,
81+
)
82+
83+
# Confirm generations of the quantized model look sane.
84+
print("========== SAMPLE GENERATION ==============")
85+
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
86+
output = model.generate(input_ids, max_new_tokens=20)
87+
print(processor.decode(output[0]))
88+
print("==========================================")
89+
90+
# Save to disk compressed.
91+
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
92+
model.save_pretrained(SAVE_DIR, save_compressed=True)
93+
processor.save_pretrained(SAVE_DIR)

src/llmcompressor/transformers/utils/data_collator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"pixtral_data_collator",
66
"llava_data_collator",
77
"qwen2_vl_data_collator",
8+
"phi3_vision_data_collator",
89
]
910

1011

@@ -46,3 +47,13 @@ def qwen2_vl_data_collator(batch):
4647
"pixel_values": torch.tensor(batch[0]["pixel_values"]),
4748
"image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]),
4849
}
50+
51+
52+
def phi3_vision_data_collator(batch):
53+
assert len(batch) == 1
54+
return {
55+
"input_ids": torch.LongTensor(batch[0]["input_ids"]),
56+
"attention_mask": torch.tensor(batch[0]["attention_mask"]),
57+
"pixel_values": torch.tensor(batch[0]["pixel_values"]),
58+
"image_sizes": torch.tensor(batch[0]["image_sizes"]),
59+
}

0 commit comments

Comments
 (0)