Skip to content

Commit 4f46f10

Browse files
committed
At least get resuming from checkpoints to work perhaps
1 parent 2375079 commit 4f46f10

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

olmocr/train/train.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,20 @@ def main():
166166
full_output_dir = os.path.join(config.training.output_dir, config.run_name)
167167
logger.info(f"Setting output directory to: {full_output_dir}")
168168

169+
# Check for existing checkpoints if resume_from_checkpoint is not specified
170+
resume_checkpoint = config.training.resume_from_checkpoint
171+
if resume_checkpoint is None and os.path.exists(full_output_dir):
172+
# Look for checkpoint directories
173+
checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))]
174+
if checkpoint_dirs:
175+
# Sort by checkpoint number and get the latest
176+
checkpoint_dirs.sort(key=lambda x: int(x.split("-")[1]))
177+
latest_checkpoint = os.path.join(full_output_dir, checkpoint_dirs[-1])
178+
logger.info(f"Found existing checkpoint: {latest_checkpoint}")
179+
resume_checkpoint = latest_checkpoint
180+
else:
181+
logger.info("No existing checkpoints found in output directory")
182+
169183
# Set up training arguments
170184
training_args = TrainingArguments(
171185
output_dir=full_output_dir,
@@ -199,7 +213,7 @@ def main():
199213
seed=config.training.seed,
200214
data_seed=config.training.data_seed,
201215
push_to_hub=False,
202-
resume_from_checkpoint=config.training.resume_from_checkpoint,
216+
resume_from_checkpoint=resume_checkpoint,
203217
dataloader_drop_last=config.training.dataloader_drop_last,
204218
dataloader_num_workers=config.training.dataloader_num_workers,
205219
remove_unused_columns=config.training.remove_unused_columns,
@@ -229,7 +243,7 @@ def main():
229243

230244
# Start training
231245
logger.info("Starting training...")
232-
train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)
246+
train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
233247

234248
# Save the final model
235249
logger.info("Saving final model...")

0 commit comments

Comments
 (0)