Skip to content

Commit 4b0960b

Browse files
committed
Test
1 parent ee69faa commit 4b0960b

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

olmocr/train/compress_checkpoint.py

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@
3131
from llmcompressor import oneshot
3232
from PIL import Image
3333
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
34-
from qwen_vl_utils import process_vision_info
3534

3635
from olmocr.s3_utils import parse_s3_path
3736
from olmocr.pipeline import build_page_query
@@ -80,24 +79,39 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[d
8079
# Get first page of each PDF (page 0)
8180
query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
8281

83-
# Extract the image and text from the query
82+
# Extract the messages
8483
messages = query["messages"]
8584

85+
# Extract images from the message content
86+
images = []
87+
for message in messages:
88+
if message.get("role") == "user":
89+
content = message.get("content", [])
90+
for item in content:
91+
if item.get("type") == "image_url":
92+
image_url = item["image_url"]["url"]
93+
# Extract base64 image data
94+
if image_url.startswith("data:image"):
95+
base64_str = image_url.split(",")[1]
96+
image_bytes = base64.b64decode(base64_str)
97+
image = Image.open(BytesIO(image_bytes))
98+
images.append(image)
99+
100+
# Apply chat template to get the text
86101
text = processor.apply_chat_template(
87102
messages, tokenize=False, add_generation_prompt=True
88103
)
89104

90-
image_inputs, video_inputs = process_vision_info(messages)
91-
92-
# tokenize
93-
return processor(
105+
# Tokenize the text and preprocess any extracted images with the processor
106+
inputs = processor(
94107
text=[text],
95-
images=image_inputs,
96-
videos=video_inputs,
108+
images=images if images else None,
97109
padding=False,
98110
max_length=8192,
99111
truncation=True,
100112
)
113+
114+
dataset.append(inputs)
101115

102116
return dataset
103117

0 commit comments

Comments
 (0)