@@ -17,14 +17,13 @@
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
-from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

 from olmocr.s3_utils import parse_s3_path

@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
         print(f" Uploaded {rel_path}")


-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None

+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config
+    model_name = config.get("name_or_path", "")
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )

     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
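The dispatch in this hunk repeats the same from_pretrained call four times: once per branch of the name_or_path check and once per branch of the architectures fallback (json is assumed to be imported earlier in the file, above the hunk shown at the top). The same detection collapses into a single class lookup; a minimal sketch of that condensed form, assuming the same config.json fields this PR reads (pick_model_class is a hypothetical helper, not part of the PR):

# Sketch: table-driven variant of the dispatch above (hypothetical refactor, not in this PR).
import json
import os

from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

def pick_model_class(model_path: str):
    with open(os.path.join(model_path, "config.json")) as f:
        config = json.load(f)
    name = config.get("name_or_path", "")
    architectures = config.get("architectures", [])
    # Prefer the explicit name hint, then fall back to the architectures list,
    # mirroring the branch order in the diff.
    if "Qwen2.5-VL" in name or "Qwen2_5_VLForConditionalGeneration" in architectures:
        return Qwen2_5_VLForConditionalGeneration
    return Qwen2VLForConditionalGeneration

Either form keeps Qwen2-VL as the default when nothing identifies the checkpoint as Qwen2.5-VL, which is the PR's backward-compatible choice.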
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen

 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)

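The quantization step of compress_checkpoint sits outside the hunks shown here. For orientation, a typical FP8 oneshot pass with llmcompressor, matching the oneshot and QuantizationModifier imports at the top of the file, looks like the sketch below; the recipe details (targets, ignore list, save flags) are assumptions from llmcompressor's documented usage, not taken from this PR.

# Sketch of an FP8 oneshot pass with llmcompressor (recipe details assumed, not from this diff).
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

recipe = QuantizationModifier(
    targets="Linear",                    # quantize Linear layers...
    scheme="FP8_DYNAMIC",                # ...to FP8 weights with dynamic activation scales
    ignore=["lm_head", "re:visual.*"],   # commonly kept in full precision for VL models
)
oneshot(model=model, recipe=recipe)

model.save_pretrained(dest_path, save_compressed=True)
tokenizer.save_pretrained(dest_path)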