Skip to content

Commit 8dcfdd0

Browse files
committed
Checkpoint prep tool
1 parent c029ccd commit 8dcfdd0

File tree

2 files changed

+301
-0
lines changed

2 files changed

+301
-0
lines changed
File renamed without changes.
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Prepares OlmOCR checkpoints for deployment by:
4+
1. Validating the model architecture
5+
2. Copying model files to destination (disk or S3)
6+
3. Downloading required tokenizer files from Hugging Face
7+
8+
Usage:
9+
python prepare_olmocr_checkpoint.py <source_path> <destination_path>
10+
11+
source_path: Path to checkpoint (local or S3)
12+
destination_path: Where to save prepared checkpoint (local or S3)
13+
"""
14+
15+
import argparse
16+
import concurrent.futures
17+
import json
18+
import os
19+
import shutil
20+
import tempfile
21+
22+
import boto3
23+
import requests
24+
from smart_open import smart_open
25+
from tqdm import tqdm
26+
27+
from olmocr.s3_utils import parse_s3_path
28+
29+
# Hugging Face model ID for tokenizer files
HF_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Base URL for fetching individual raw files from the HF model repo.
HF_BASE_URL = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main"

# Required tokenizer files to download from Hugging Face
TOKENIZER_FILES = [
    "chat_template.json",
    "merges.txt",
    "preprocessor_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json"
]

# Expected model architecture
EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration"

# Module-level S3 client shared by all copy helpers; it is reused from the
# ThreadPoolExecutor workers below (boto3 documents clients as thread-safe).
s3_client = boto3.client("s3")
47+
48+
49+
def is_s3_path(path: str) -> bool:
    """Return True if *path* is an S3 URI (i.e. begins with "s3://")."""
    s3_scheme = "s3://"
    return path[: len(s3_scheme)] == s3_scheme
52+
53+
54+
def download_file_from_hf(filename: str, destination_dir: str) -> None:
    """Download one tokenizer file from the Hugging Face model repository.

    Args:
        filename: File name inside the HF repo (one of TOKENIZER_FILES).
        destination_dir: Existing local directory to write the file into.

    Raises:
        requests.HTTPError: If the HTTP request fails (non-2xx status).
    """
    # BUG FIX: the URL and log messages contained the literal text
    # "(unknown)" where the {filename} interpolation belongs, so every
    # download hit a nonexistent URL.
    url = f"{HF_BASE_URL}/{filename}"
    local_path = os.path.join(destination_dir, filename)

    print(f"Downloading {filename} from Hugging Face...")
    # Stream so large files are not buffered fully in memory; the timeout
    # keeps a stalled connection from hanging the worker thread, and the
    # context manager guarantees the connection is released.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

    print(f"Downloaded {filename}")
68+
69+
70+
def validate_checkpoint_architecture(config_path: str) -> None:
    """Ensure the checkpoint's config.json declares the expected architecture.

    Args:
        config_path: Path (local or s3://) to the checkpoint's config.json;
            smart_open handles both schemes transparently.

    Raises:
        ValueError: If EXPECTED_ARCHITECTURE is not listed in the config.
    """
    print(f"Validating checkpoint architecture from {config_path}...")

    with smart_open(config_path, "r") as config_file:
        parsed_config = json.load(config_file)

    architectures = parsed_config.get("architectures", [])
    if EXPECTED_ARCHITECTURE in architectures:
        print(f"✓ Valid architecture: {architectures}")
        return

    raise ValueError(
        f"Invalid model architecture. Expected '{EXPECTED_ARCHITECTURE}' "
        f"but found: {architectures}"
    )
85+
86+
87+
def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
    """Recursively copy every file under source_dir into dest_dir.

    The directory structure relative to source_dir is preserved and file
    metadata is kept (shutil.copy2).
    """
    os.makedirs(dest_dir, exist_ok=True)

    # Collect (source, destination) path pairs for the whole tree up front
    # so tqdm can report accurate progress.
    files_to_copy = [
        (
            os.path.join(root, name),
            os.path.join(dest_dir, os.path.relpath(os.path.join(root, name), source_dir)),
        )
        for root, _, names in os.walk(source_dir)
        for name in names
    ]

    print(f"Copying {len(files_to_copy)} files from {source_dir} to {dest_dir}...")

    for src_path, dst_path in tqdm(files_to_copy, desc="Copying files"):
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        shutil.copy2(src_path, dst_path)
104+
105+
106+
def download_file_from_s3(bucket: str, key: str, local_path: str) -> None:
    """Download a single object from S3 to local_path.

    Args:
        bucket: Source S3 bucket name.
        key: Object key within the bucket.
        local_path: Local file path to write the object to.
    """
    # BUG FIX: os.makedirs("") raises FileNotFoundError, so only create
    # parent directories when local_path actually has a directory component.
    parent_dir = os.path.dirname(local_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    s3_client.download_file(bucket, key, local_path)
110+
111+
112+
def upload_file_to_s3(local_path: str, bucket: str, key: str) -> None:
    """Upload a single file to S3."""
    # Thin wrapper so uploads can be fanned out via ThreadPoolExecutor.submit.
    s3_client.upload_file(local_path, bucket, key)
115+
116+
117+
def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> None:
    """Copy files from S3 to local directory.

    Every object under source_prefix is downloaded into dest_dir, keeping
    the key structure relative to the prefix.
    """
    os.makedirs(dest_dir, exist_ok=True)

    # Enumerate every object under the prefix, skipping "directory" marker
    # keys, and compute each object's local destination path.
    download_tasks = []
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=source_bucket, Prefix=source_prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if not key.endswith("/"):
                rel_path = os.path.relpath(key, source_prefix)
                download_tasks.append(
                    (source_bucket, key, os.path.join(dest_dir, rel_path))
                )

    print(f"Downloading {len(download_tasks)} files from s3://{source_bucket}/{source_prefix} to {dest_dir}...")

    # Fan the downloads out over a small thread pool; future.result()
    # re-raises any download error on the main thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(download_file_from_s3, *task) for task in download_tasks]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Downloading"):
            future.result()
146+
147+
148+
def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> None:
    """Upload every file under source_dir to s3://dest_bucket/dest_prefix.

    The directory structure relative to source_dir is preserved in the
    object keys.

    Args:
        source_dir: Local directory to walk.
        dest_bucket: Destination S3 bucket name.
        dest_prefix: Key prefix to upload under.
    """
    import posixpath

    # Build (local_path, bucket, key) upload tasks. S3 keys must use "/"
    # separators, so join with posixpath and normalize os.sep in the
    # relative path — os.path.join would emit backslash-separated keys on
    # Windows.
    upload_tasks = []
    for root, _, files in os.walk(source_dir):
        for file in files:
            local_path = os.path.join(root, file)
            rel_path = os.path.relpath(local_path, source_dir)
            s3_key = posixpath.join(dest_prefix, rel_path.replace(os.sep, "/"))
            upload_tasks.append((local_path, dest_bucket, s3_key))

    print(f"Uploading {len(upload_tasks)} files from {source_dir} to s3://{dest_bucket}/{dest_prefix}...")

    # Upload in parallel; future.result() re-raises any upload failure.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(upload_file_to_s3, local_path, bucket, key)
            for local_path, bucket, key in upload_tasks
        ]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Uploading"):
            future.result()
169+
170+
171+
def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest_prefix: str) -> None:
    """Server-side copy of all objects under one S3 prefix to another.

    Args:
        source_bucket: Bucket to copy from.
        source_prefix: Key prefix to copy from.
        dest_bucket: Bucket to copy into.
        dest_prefix: Key prefix to copy into.
    """
    import posixpath

    # List all objects in source
    paginator = s3_client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)

    copy_tasks = []
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue

            # S3 keys always use "/" separators, so manipulate them with
            # posixpath — os.path.relpath/join would use "\" on Windows
            # and produce malformed destination keys.
            rel_path = posixpath.relpath(key, source_prefix)
            dest_key = posixpath.join(dest_prefix, rel_path)
            copy_source = {"Bucket": source_bucket, "Key": key}
            copy_tasks.append((copy_source, dest_bucket, dest_key))

    print(f"Copying {len(copy_tasks)} files from s3://{source_bucket}/{source_prefix} to s3://{dest_bucket}/{dest_prefix}...")

    # copy_object performs a server-side copy (no local download).
    # NOTE(review): the S3 CopyObject API is limited to objects up to 5 GB;
    # larger shards would need multipart copy — confirm checkpoint shard sizes.
    for copy_source, bucket, key in tqdm(copy_tasks, desc="Copying"):
        s3_client.copy_object(CopySource=copy_source, Bucket=bucket, Key=key)
193+
194+
195+
def prepare_checkpoint(source_path: str, dest_path: str) -> None:
    """Prepare an OlmOCR checkpoint for deployment.

    Validates the source checkpoint's architecture, copies the model files
    to the destination (local or S3), then fetches the required tokenizer
    files from Hugging Face into the destination.

    Args:
        source_path: Checkpoint location — local directory or s3:// URI.
        dest_path: Destination — local directory or s3:// URI.

    Raises:
        ValueError: If the checkpoint does not have EXPECTED_ARCHITECTURE.
    """
    import posixpath

    # First, validate the source checkpoint. smart_open reads both local
    # paths and s3:// URIs; only the join style differs.
    if is_s3_path(source_path):
        config_path = f"{source_path}/config.json"
    else:
        config_path = os.path.join(source_path, "config.json")

    validate_checkpoint_architecture(config_path)

    # Copy model files to destination, dispatching on where source and
    # destination live.
    print("\nCopying model files...")
    if is_s3_path(source_path) and is_s3_path(dest_path):
        # S3 to S3
        source_bucket, source_prefix = parse_s3_path(source_path)
        dest_bucket, dest_prefix = parse_s3_path(dest_path)
        copy_s3_to_s3(source_bucket, source_prefix, dest_bucket, dest_prefix)
    elif is_s3_path(source_path):
        # S3 to local
        source_bucket, source_prefix = parse_s3_path(source_path)
        copy_s3_to_local(source_bucket, source_prefix, dest_path)
    elif is_s3_path(dest_path):
        # Local to S3
        dest_bucket, dest_prefix = parse_s3_path(dest_path)
        copy_local_to_s3(source_path, dest_bucket, dest_prefix)
    else:
        # Local to local
        copy_local_to_local(source_path, dest_path)

    # Download tokenizer files from Hugging Face
    print("\nDownloading tokenizer files from Hugging Face...")

    if is_s3_path(dest_path):
        # Download to a temp directory first, then upload to S3.
        with tempfile.TemporaryDirectory() as temp_dir:
            with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
                futures = [
                    executor.submit(download_file_from_hf, filename, temp_dir)
                    for filename in TOKENIZER_FILES
                ]
                for future in concurrent.futures.as_completed(futures):
                    future.result()

            # Upload to S3. BUG FIX: keys are joined with posixpath so the
            # separator is "/" on every platform (os.path.join would emit
            # backslashes on Windows).
            dest_bucket, dest_prefix = parse_s3_path(dest_path)
            upload_tasks = [
                (
                    os.path.join(temp_dir, filename),
                    dest_bucket,
                    posixpath.join(dest_prefix, filename),
                )
                for filename in TOKENIZER_FILES
            ]

            print("Uploading tokenizer files to S3...")
            with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
                futures = [
                    executor.submit(upload_file_to_s3, local_path, bucket, key)
                    for local_path, bucket, key in upload_tasks
                ]
                for future in concurrent.futures.as_completed(futures):
                    future.result()
    else:
        # Local destination: download directly into it.
        with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
            futures = [
                executor.submit(download_file_from_hf, filename, dest_path)
                for filename in TOKENIZER_FILES
            ]
            for future in concurrent.futures.as_completed(futures):
                future.result()

    print(f"\n✓ Successfully prepared checkpoint at {dest_path}")
265+
266+
267+
def main() -> int:
    """CLI entry point: parse arguments and prepare the checkpoint.

    Returns:
        Process exit code — 0 on success, 1 on any failure.
    """
    import sys

    parser = argparse.ArgumentParser(
        description="Prepare OlmOCR checkpoint for deployment",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Local to local
  python prepare_olmocr_checkpoint.py /path/to/checkpoint /path/to/output

  # S3 to S3
  python prepare_olmocr_checkpoint.py s3://bucket/checkpoint s3://bucket/prepared

  # S3 to local
  python prepare_olmocr_checkpoint.py s3://bucket/checkpoint /path/to/output

  # Local to S3
  python prepare_olmocr_checkpoint.py /path/to/checkpoint s3://bucket/prepared
""",
    )
    parser.add_argument("source", help="Source checkpoint path (local or S3)")
    parser.add_argument("destination", help="Destination path (local or S3)")

    args = parser.parse_args()

    try:
        prepare_checkpoint(args.source, args.destination)
    except Exception as e:
        # BUG FIX: report failures on stderr so scripted callers can
        # separate the error from normal progress output on stdout.
        print(f"\n❌ Error: {e}", file=sys.stderr)
        return 1

    return 0
298+
299+
300+
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the bare exit() builtin: exit() is injected
    # by the site module and is not guaranteed to exist in every runtime.
    sys.exit(main())

0 commit comments

Comments
 (0)