@@ -134,6 +134,22 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditio
134
134
return model , tokenizer , temp_dir
135
135
136
136
137
def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Optional[str] = None) -> None:
    """Carry over auxiliary config files that ``save_pretrained`` does not write.

    Looks for ``preprocessor_config.json`` and ``chat_template.json`` in the
    source directory and copies each one that exists into ``dest_path``.
    Files that are absent are skipped silently.

    Args:
        source_path: Location the checkpoint was loaded from.
        dest_path: Existing local directory that receives the copies.
        temp_source_dir: Local directory to read from instead of
            ``source_path`` when it is set (e.g. a temp dir holding a
            checkpoint that was downloaded from S3).
    """
    # A non-empty temp dir takes precedence over the nominal source path.
    origin_dir = temp_source_dir or source_path

    for filename in ("preprocessor_config.json", "chat_template.json"):
        candidate = os.path.join(origin_dir, filename)
        if not os.path.exists(candidate):
            # Optional file: nothing to copy for this name.
            continue

        target = os.path.join(dest_path, filename)
        print(f"Copying {filename} to destination...")
        # copy2 copies the file contents and attempts to preserve metadata.
        shutil.copy2(candidate, target)
151
+
152
+
137
153
def compress_checkpoint (source_path : str , dest_path : str ) -> None :
138
154
"""Compress OlmOCR checkpoint using FP8 quantization."""
139
155
# Load model and tokenizer
@@ -160,6 +176,9 @@ def compress_checkpoint(source_path: str, dest_path: str) -> None:
160
176
model .save_pretrained (temp_dest_dir )
161
177
tokenizer .save_pretrained (temp_dest_dir )
162
178
179
+ # Copy additional files
180
+ copy_additional_files (source_path , temp_dest_dir , temp_source_dir )
181
+
163
182
# Upload to S3
164
183
bucket , prefix = parse_s3_path (dest_path )
165
184
upload_local_to_s3 (temp_dest_dir , bucket , prefix )
@@ -169,6 +188,9 @@ def compress_checkpoint(source_path: str, dest_path: str) -> None:
169
188
os .makedirs (dest_path , exist_ok = True )
170
189
model .save_pretrained (dest_path )
171
190
tokenizer .save_pretrained (dest_path )
191
+
192
+ # Copy additional files
193
+ copy_additional_files (source_path , dest_path , temp_source_dir )
172
194
173
195
print (f"\n ✓ Successfully compressed checkpoint and saved to { dest_path } " )
174
196
0 commit comments