#!/usr/bin/env python3
"""
Prepares OlmOCR checkpoints for deployment by:
1. Validating the model architecture
2. Copying model files to the destination (disk or S3)
3. Downloading the required tokenizer files from Hugging Face

Usage:
    python prepare_olmocr_checkpoint.py <source_path> <destination_path>

    source_path: Path to the checkpoint (local or S3)
    destination_path: Where to save the prepared checkpoint (local or S3)
"""

import argparse
import concurrent.futures
import json
import os
import shutil
import sys
import tempfile

import boto3
import requests
from smart_open import smart_open
from tqdm import tqdm

from olmocr.s3_utils import parse_s3_path
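# parse_s3_path is assumed here to split "s3://bucket/prefix" into a
# (bucket, prefix) tuple, which is how the copy helpers below unpack it.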

# Hugging Face model ID for tokenizer files
HF_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
HF_BASE_URL = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main"

# Required tokenizer files to download from Hugging Face
TOKENIZER_FILES = [
    "chat_template.json",
    "merges.txt",
    "preprocessor_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
]

# Expected model architecture
EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration"

s3_client = boto3.client("s3")
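# Note: the module-level client above is shared by the ThreadPoolExecutor
# workers used for the parallel copies below; boto3 clients are thread-safe,
# so this is fine (boto3 *resources* would not be).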


def is_s3_path(path: str) -> bool:
    """Check if a path is an S3 path."""
    return path.startswith("s3://")
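
# Example: is_s3_path("s3://bucket/checkpoint") -> True;
#          is_s3_path("/tmp/checkpoint")        -> False.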


def download_file_from_hf(filename: str, destination_dir: str) -> None:
    """Download a file from the Hugging Face model repository."""
    url = f"{HF_BASE_URL}/{filename}"
    local_path = os.path.join(destination_dir, filename)

    print(f"Downloading {filename} from Hugging Face...")
    # Stream the response so large files are not buffered in memory; the
    # timeout keeps a stalled connection from hanging the script indefinitely.
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()

    with open(local_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    print(f"Downloaded {filename}")
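

# Illustrative sketch (not part of the original script): downloads from the
# Hugging Face CDN can fail transiently, so a caller could wrap
# download_file_from_hf in a small retry loop with exponential backoff.
# The helper name and retry counts here are assumptions, not an established API.
def download_file_from_hf_with_retries(filename: str, destination_dir: str, attempts: int = 3) -> None:
    """Retry wrapper around download_file_from_hf with exponential backoff."""
    import time  # local import keeps this optional sketch self-contained

    for attempt in range(1, attempts + 1):
        try:
            download_file_from_hf(filename, destination_dir)
            return
        except requests.RequestException:
            if attempt == attempts:
                raise
            time.sleep(2**attempt)  # back off: 2s, 4s, ...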


def validate_checkpoint_architecture(config_path: str) -> None:
    """Validate that the checkpoint has the expected architecture."""
    print(f"Validating checkpoint architecture from {config_path}...")

    with smart_open(config_path, "r") as f:
        config_data = json.load(f)

    architectures = config_data.get("architectures", [])
    if EXPECTED_ARCHITECTURE not in architectures:
        raise ValueError(
            f"Invalid model architecture. Expected '{EXPECTED_ARCHITECTURE}' "
            f"but found: {architectures}"
        )

    print(f"✓ Valid architecture: {architectures}")
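
# For reference, the relevant portion of a config.json that passes the check
# above looks like (abridged):
#
#   {"architectures": ["Qwen2_5_VLForConditionalGeneration"], ...}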


def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
    """Copy files from a local directory to a local directory."""
    os.makedirs(dest_dir, exist_ok=True)

    # Get list of files to copy, preserving paths relative to source_dir
    files_to_copy = []
    for root, _, files in os.walk(source_dir):
        for file in files:
            src_path = os.path.join(root, file)
            rel_path = os.path.relpath(src_path, source_dir)
            files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))

    print(f"Copying {len(files_to_copy)} files from {source_dir} to {dest_dir}...")

    for src_path, dst_path in tqdm(files_to_copy, desc="Copying files"):
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        shutil.copy2(src_path, dst_path)


def download_file_from_s3(bucket: str, key: str, local_path: str) -> None:
    """Download a single file from S3."""
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    s3_client.download_file(bucket, key, local_path)


def upload_file_to_s3(local_path: str, bucket: str, key: str) -> None:
    """Upload a single file to S3."""
    s3_client.upload_file(local_path, bucket, key)


def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> None:
    """Copy files from S3 to a local directory."""
    os.makedirs(dest_dir, exist_ok=True)

    # List all objects in source
    paginator = s3_client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)

    download_tasks = []
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue

            rel_path = os.path.relpath(key, source_prefix)
            local_path = os.path.join(dest_dir, rel_path)
            download_tasks.append((source_bucket, key, local_path))

    print(f"Downloading {len(download_tasks)} files from s3://{source_bucket}/{source_prefix} to {dest_dir}...")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(download_file_from_s3, bucket, key, local_path)
            for bucket, key, local_path in download_tasks
        ]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Downloading"):
            future.result()
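
# Note: iterating as_completed and calling future.result() re-raises any
# exception from a worker thread, so a failed transfer aborts the script
# instead of being silently swallowed.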


def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> None:
    """Copy files from a local directory to S3."""
    # Get list of files to upload
    upload_tasks = []
    for root, _, files in os.walk(source_dir):
        for file in files:
            local_path = os.path.join(root, file)
            rel_path = os.path.relpath(local_path, source_dir)
            # S3 keys always use forward slashes, regardless of the local OS separator
            s3_key = os.path.join(dest_prefix, rel_path).replace(os.sep, "/")
            upload_tasks.append((local_path, dest_bucket, s3_key))

    print(f"Uploading {len(upload_tasks)} files from {source_dir} to s3://{dest_bucket}/{dest_prefix}...")

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(upload_file_to_s3, local_path, bucket, key)
            for local_path, bucket, key in upload_tasks
        ]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Uploading"):
            future.result()


def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest_prefix: str) -> None:
    """Copy files from S3 to S3."""
    # List all objects in source
    paginator = s3_client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)

    copy_tasks = []
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue

            rel_path = os.path.relpath(key, source_prefix)
            dest_key = os.path.join(dest_prefix, rel_path).replace(os.sep, "/")
            copy_source = {"Bucket": source_bucket, "Key": key}
            copy_tasks.append((copy_source, dest_bucket, dest_key))

    print(f"Copying {len(copy_tasks)} files from s3://{source_bucket}/{source_prefix} to s3://{dest_bucket}/{dest_prefix}...")

    for copy_source, bucket, key in tqdm(copy_tasks, desc="Copying"):
        # Use the managed copy() rather than copy_object(): copy_object is
        # capped at 5 GB per object, and model shards often exceed that.
        s3_client.copy(copy_source, bucket, key)


def prepare_checkpoint(source_path: str, dest_path: str) -> None:
    """Prepare an OlmOCR checkpoint for deployment."""
    # First, validate the source checkpoint
    config_path = os.path.join(source_path, "config.json")
    if is_s3_path(source_path):
        config_path = f"{source_path}/config.json"

    validate_checkpoint_architecture(config_path)

    # Copy model files to the destination
    print("\nCopying model files...")
    if is_s3_path(source_path) and is_s3_path(dest_path):
        # S3 to S3
        source_bucket, source_prefix = parse_s3_path(source_path)
        dest_bucket, dest_prefix = parse_s3_path(dest_path)
        copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
    elif is_s3_path(source_path) and not is_s3_path(dest_path):
        # S3 to local
        source_bucket, source_prefix = parse_s3_path(source_path)
        copy_s3_to_local(source_bucket, source_prefix, dest_path)
    elif not is_s3_path(source_path) and is_s3_path(dest_path):
        # Local to S3
        dest_bucket, dest_prefix = parse_s3_path(dest_path)
        copy_local_to_s3(source_path, dest_bucket, dest_prefix)
    else:
        # Local to local
        copy_local_to_local(source_path, dest_path)

    # Download tokenizer files from Hugging Face
    print("\nDownloading tokenizer files from Hugging Face...")

    if is_s3_path(dest_path):
        # Download to a temp directory first, then upload to S3
        with tempfile.TemporaryDirectory() as temp_dir:
            # Download files
            with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
                futures = [
                    executor.submit(download_file_from_hf, filename, temp_dir)
                    for filename in TOKENIZER_FILES
                ]
                for future in concurrent.futures.as_completed(futures):
                    future.result()

            # Upload to S3
            dest_bucket, dest_prefix = parse_s3_path(dest_path)
            upload_tasks = []
            for filename in TOKENIZER_FILES:
                local_path = os.path.join(temp_dir, filename)
                s3_key = os.path.join(dest_prefix, filename).replace(os.sep, "/")
                upload_tasks.append((local_path, dest_bucket, s3_key))

            print("Uploading tokenizer files to S3...")
            with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
                futures = [
                    executor.submit(upload_file_to_s3, local_path, bucket, key)
                    for local_path, bucket, key in upload_tasks
                ]
                for future in concurrent.futures.as_completed(futures):
                    future.result()
    else:
        # Download directly to the destination
        with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
            futures = [
                executor.submit(download_file_from_hf, filename, dest_path)
                for filename in TOKENIZER_FILES
            ]
            for future in concurrent.futures.as_completed(futures):
                future.result()

    print(f"\n✓ Successfully prepared checkpoint at {dest_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Prepare OlmOCR checkpoint for deployment",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Local to local
  python prepare_olmocr_checkpoint.py /path/to/checkpoint /path/to/output

  # S3 to S3
  python prepare_olmocr_checkpoint.py s3://bucket/checkpoint s3://bucket/prepared

  # S3 to local
  python prepare_olmocr_checkpoint.py s3://bucket/checkpoint /path/to/output

  # Local to S3
  python prepare_olmocr_checkpoint.py /path/to/checkpoint s3://bucket/prepared
        """,
    )
    parser.add_argument("source", help="Source checkpoint path (local or S3)")
    parser.add_argument("destination", help="Destination path (local or S3)")

    args = parser.parse_args()

    try:
        prepare_checkpoint(args.source, args.destination)
    except Exception as e:
        print(f"\n❌ Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())