@@ -166,6 +166,20 @@ def main():
166
166
full_output_dir = os .path .join (config .training .output_dir , config .run_name )
167
167
logger .info (f"Setting output directory to: { full_output_dir } " )
168
168
169
+ # Check for existing checkpoints if resume_from_checkpoint is not specified
170
+ resume_checkpoint = config .training .resume_from_checkpoint
171
+ if resume_checkpoint is None and os .path .exists (full_output_dir ):
172
+ # Look for checkpoint directories
173
+ checkpoint_dirs = [d for d in os .listdir (full_output_dir ) if d .startswith ("checkpoint-" ) and os .path .isdir (os .path .join (full_output_dir , d ))]
174
+ if checkpoint_dirs :
175
+ # Sort by checkpoint number and get the latest
176
+ checkpoint_dirs .sort (key = lambda x : int (x .split ("-" )[1 ]))
177
+ latest_checkpoint = os .path .join (full_output_dir , checkpoint_dirs [- 1 ])
178
+ logger .info (f"Found existing checkpoint: { latest_checkpoint } " )
179
+ resume_checkpoint = latest_checkpoint
180
+ else :
181
+ logger .info ("No existing checkpoints found in output directory" )
182
+
169
183
# Set up training arguments
170
184
training_args = TrainingArguments (
171
185
output_dir = full_output_dir ,
@@ -199,7 +213,7 @@ def main():
199
213
seed = config .training .seed ,
200
214
data_seed = config .training .data_seed ,
201
215
push_to_hub = False ,
202
- resume_from_checkpoint = config . training . resume_from_checkpoint ,
216
+ resume_from_checkpoint = resume_checkpoint ,
203
217
dataloader_drop_last = config .training .dataloader_drop_last ,
204
218
dataloader_num_workers = config .training .dataloader_num_workers ,
205
219
remove_unused_columns = config .training .remove_unused_columns ,
@@ -229,7 +243,7 @@ def main():
229
243
230
244
# Start training
231
245
logger .info ("Starting training..." )
232
- train_result = trainer .train (resume_from_checkpoint = config . training . resume_from_checkpoint )
246
+ train_result = trainer .train (resume_from_checkpoint = resume_checkpoint )
233
247
234
248
# Save the final model
235
249
logger .info ("Saving final model..." )
0 commit comments