
Commit 01360ba: Compressor script
1 parent: 1ede76d

1 file changed (+44, -14 lines)

olmocr/train/compress_checkpoint.py (+44, -14)
```diff
@@ -17,14 +17,13 @@
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
-from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 
 from olmocr.s3_utils import parse_s3_path
 
```
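The import hunk drops the now-unused smart_open import and adds Qwen2_5_VLForConditionalGeneration alongside the existing Qwen2-VL class. Since that class only exists in transformers releases that ship Qwen2.5-VL support, a guarded import is one way to keep the script usable on older installs; this is a sketch, not part of the commit:

```python
# Sketch only (not in this commit): tolerate older transformers builds
# that predate Qwen2.5-VL support.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration
except ImportError:
    Qwen2_5_VLForConditionalGeneration = None  # only Qwen2-VL will be loadable
```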
```diff
@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
         print(f" Uploaded {rel_path}")
 
 
-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
```
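The S3 branch of load_model_and_tokenizer falls outside this hunk. For context, a minimal sketch of what such a download step typically looks like with the imports above; the helper name and the assumption that parse_s3_path returns a (bucket, prefix) pair are mine, not the repo's:

```python
# Hypothetical sketch of the S3 download branch; assumes parse_s3_path
# returns (bucket, prefix). Not taken from the actual file.
def download_checkpoint_from_s3(s3_path: str) -> str:
    bucket, prefix = parse_s3_path(s3_path)
    temp_dir = tempfile.mkdtemp()
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            rel_path = os.path.relpath(obj["Key"], prefix)
            local_path = os.path.join(temp_dir, rel_path)
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            s3.download_file(bucket, obj["Key"], local_path)
    return temp_dir
```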
```diff
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None
 
+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config
+    model_name = config.get("name_or_path", "")
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
 
     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
```
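The detection order here is: substring match on the config's name_or_path ("Qwen2.5-VL" first, then "Qwen2-VL"), falling back to the config's architectures list, and defaulting to Qwen2-VL when neither matches. The added block relies on json being imported earlier in the file (the first 16 lines fall outside this diff). The same logic, condensed into a helper for illustration (not part of the commit):

```python
# Illustrative refactor of the commit's detection logic; behavior is the same.
def pick_model_class(config: dict):
    name = config.get("name_or_path", "")
    if "Qwen2.5-VL" in name:
        return Qwen2_5_VLForConditionalGeneration
    if "Qwen2-VL" in name:
        return Qwen2VLForConditionalGeneration
    if "Qwen2_5_VLForConditionalGeneration" in config.get("architectures", []):
        return Qwen2_5_VLForConditionalGeneration
    return Qwen2VLForConditionalGeneration  # default, as in the commit

# model = pick_model_class(config).from_pretrained(
#     model_path, device_map="auto", torch_dtype="auto")
```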
```diff
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
 
 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
    # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
 
```
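The removed block computed a config_path for pre-validation; config reading now lives inside load_model_and_tokenizer, which handles both local and S3 sources. The rest of compress_checkpoint is not shown in this diff; given the oneshot and QuantizationModifier imports and the FP8 docstring, the quantization step presumably resembles llmcompressor's standard FP8_DYNAMIC recipe. A hedged sketch under that assumption (the ignore list in particular is a guess):

```python
# Assumed shape of the FP8 step, based on llmcompressor's documented
# FP8_DYNAMIC usage; not copied from compress_checkpoint itself.
recipe = QuantizationModifier(
    targets="Linear",                   # quantize the linear layers
    scheme="FP8_DYNAMIC",               # FP8 weights, dynamic activation scales
    ignore=["lm_head", "re:visual.*"],  # guess: keep head and vision tower unquantized
)
oneshot(model=model, recipe=recipe)
model.save_pretrained(dest_dir, save_compressed=True)  # dest_dir: hypothetical name
tokenizer.save_pretrained(dest_dir)
```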