@@ -166,17 +166,17 @@ def main():
166
166
full_output_dir = os .path .join (config .training .output_dir , config .run_name )
167
167
logger .info (f"Setting output directory to: { full_output_dir } " )
168
168
169
- # Check for existing checkpoints if resume_from_checkpoint is not specified
170
- resume_checkpoint = config . training . resume_from_checkpoint
171
- if resume_checkpoint is None and os .path .exists (full_output_dir ):
169
+ # Check for existing checkpoints if any
170
+ found_resumable_checkpoint = False
171
+ if os .path .exists (full_output_dir ):
172
172
# Look for checkpoint directories
173
173
checkpoint_dirs = [d for d in os .listdir (full_output_dir ) if d .startswith ("checkpoint-" ) and os .path .isdir (os .path .join (full_output_dir , d ))]
174
174
if checkpoint_dirs :
175
175
# Sort by checkpoint number and get the latest
176
176
checkpoint_dirs .sort (key = lambda x : int (x .split ("-" )[1 ]))
177
177
latest_checkpoint = os .path .join (full_output_dir , checkpoint_dirs [- 1 ])
178
178
logger .info (f"Found existing checkpoint: { latest_checkpoint } " )
179
- resume_checkpoint = latest_checkpoint
179
+ found_resumable_checkpoint = True
180
180
else :
181
181
logger .info ("No existing checkpoints found in output directory" )
182
182
@@ -213,7 +213,6 @@ def main():
213
213
seed = config .training .seed ,
214
214
data_seed = config .training .data_seed ,
215
215
push_to_hub = False ,
216
- resume_from_checkpoint = resume_checkpoint ,
217
216
dataloader_drop_last = config .training .dataloader_drop_last ,
218
217
dataloader_num_workers = config .training .dataloader_num_workers ,
219
218
remove_unused_columns = config .training .remove_unused_columns ,
@@ -243,7 +242,7 @@ def main():
243
242
244
243
# Start training
245
244
logger .info ("Starting training..." )
246
- train_result = trainer .train (resume_from_checkpoint = resume_checkpoint )
245
+ train_result = trainer .train (resume_from_checkpoint = found_resumable_checkpoint )
247
246
248
247
# Save the final model
249
248
logger .info ("Saving final model..." )
0 commit comments