Skip to content

Commit f306a52

Browse files
committed
Compress fix
1 parent 01360ba commit f306a52

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

olmocr/train/compress_checkpoint.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,22 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditio
134134
return model, tokenizer, temp_dir
135135

136136

137+
def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Optional[str] = None) -> None:
    """Copy auxiliary config files that ``save_pretrained`` does not write.

    Looks for ``preprocessor_config.json`` and ``chat_template.json`` in the
    source checkpoint directory and copies each one that exists as a regular
    file into ``dest_path``. Files that are absent are skipped silently.

    Args:
        source_path: Checkpoint location, used as the lookup directory when
            ``temp_source_dir`` is not given.
        dest_path: Local directory that receives the files. It is created on
            demand if a file needs to be copied and it does not exist yet.
        temp_source_dir: Local directory holding the source checkpoint (for
            example a temporary download of an S3 checkpoint). When provided
            and non-empty it takes precedence over ``source_path``.
    """
    # List of additional files to copy if they exist
    additional_files = ["preprocessor_config.json", "chat_template.json"]

    # Determine the actual source path (could be temp dir if downloaded from S3)
    actual_source = temp_source_dir if temp_source_dir else source_path

    for filename in additional_files:
        source_file = os.path.join(actual_source, filename)
        # isfile (not exists): a directory that happens to carry one of these
        # names must not be handed to shutil.copy2, which would raise.
        if not os.path.isfile(source_file):
            continue

        # Make sure the target directory is there before copying into it.
        os.makedirs(dest_path, exist_ok=True)
        dest_file = os.path.join(dest_path, filename)
        print(f"Copying {filename} to destination...")
        try:
            shutil.copy2(source_file, dest_file)
        except shutil.SameFileError:
            # Source and destination are the same file (in-place compression):
            # the file is already where it needs to be.
            continue
151+
152+
137153
def compress_checkpoint(source_path: str, dest_path: str) -> None:
138154
"""Compress OlmOCR checkpoint using FP8 quantization."""
139155
# Load model and tokenizer
@@ -160,6 +176,9 @@ def compress_checkpoint(source_path: str, dest_path: str) -> None:
160176
model.save_pretrained(temp_dest_dir)
161177
tokenizer.save_pretrained(temp_dest_dir)
162178

179+
# Copy additional files
180+
copy_additional_files(source_path, temp_dest_dir, temp_source_dir)
181+
163182
# Upload to S3
164183
bucket, prefix = parse_s3_path(dest_path)
165184
upload_local_to_s3(temp_dest_dir, bucket, prefix)
@@ -169,6 +188,9 @@ def compress_checkpoint(source_path: str, dest_path: str) -> None:
169188
os.makedirs(dest_path, exist_ok=True)
170189
model.save_pretrained(dest_path)
171190
tokenizer.save_pretrained(dest_path)
191+
192+
# Copy additional files
193+
copy_additional_files(source_path, dest_path, temp_source_dir)
172194

173195
print(f"\n✓ Successfully compressed checkpoint and saved to {dest_path}")
174196

0 commit comments

Comments
 (0)