Skip to content

Commit ee8bd9b

Browse files
committed
Better resume logic I hope
1 parent 208fabc commit ee8bd9b

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

olmocr/train/config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,6 @@ class TrainingConfig:
168168
seed: int = 42
169169
data_seed: Optional[int] = None
170170

171-
# Resume from checkpoint
172-
resume_from_checkpoint: Optional[str] = None
173-
174171
# Performance
175172
dataloader_drop_last: bool = True
176173
dataloader_num_workers: int = 4

olmocr/train/train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,17 @@ def main():
166166
full_output_dir = os.path.join(config.training.output_dir, config.run_name)
167167
logger.info(f"Setting output directory to: {full_output_dir}")
168168

169-
# Check for existing checkpoints if resume_from_checkpoint is not specified
170-
resume_checkpoint = config.training.resume_from_checkpoint
171-
if resume_checkpoint is None and os.path.exists(full_output_dir):
169+
# Check for existing checkpoints if any
170+
found_resumable_checkpoint = False
171+
if os.path.exists(full_output_dir):
172172
# Look for checkpoint directories
173173
checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))]
174174
if checkpoint_dirs:
175175
# Sort by checkpoint number and get the latest
176176
checkpoint_dirs.sort(key=lambda x: int(x.split("-")[1]))
177177
latest_checkpoint = os.path.join(full_output_dir, checkpoint_dirs[-1])
178178
logger.info(f"Found existing checkpoint: {latest_checkpoint}")
179-
resume_checkpoint = latest_checkpoint
179+
found_resumable_checkpoint = True
180180
else:
181181
logger.info("No existing checkpoints found in output directory")
182182

@@ -213,7 +213,6 @@ def main():
213213
seed=config.training.seed,
214214
data_seed=config.training.data_seed,
215215
push_to_hub=False,
216-
resume_from_checkpoint=resume_checkpoint,
217216
dataloader_drop_last=config.training.dataloader_drop_last,
218217
dataloader_num_workers=config.training.dataloader_num_workers,
219218
remove_unused_columns=config.training.remove_unused_columns,
@@ -243,7 +242,7 @@ def main():
243242

244243
# Start training
245244
logger.info("Starting training...")
246-
train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
245+
train_result = trainer.train(resume_from_checkpoint=found_resumable_checkpoint)
247246

248247
# Save the final model
249248
logger.info("Saving final model...")

0 commit comments

Comments
 (0)