
Commit 2318d04

fix(base_trainer): gather weights in save_pretrained under zero3 (#429)
* feat(configs): make saving optimizer state optional
* fix(base_trainer): `save_pretrained` under zero3
* style
* revert(configs): revert to the default, save-the-whole-state behaviour
1 parent 92b68e4 commit 2318d04

File tree: 3 files changed, +29 -20 lines

trlx/data/configs.py

Lines changed: 1 addition & 0 deletions
@@ -220,6 +220,7 @@ class TrainConfig:
     checkpoint_dir: str = "ckpts"
     rollout_logging_dir: Optional[str] = None
     save_best: bool = True
+    save_optimizer: bool = True

     tracker: Optional[str] = "wandb"
     logging_dir: Optional[str] = None
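
The new field defaults to True, so nothing changes unless a user opts out; when it is False, the trainer below falls back to `save_pretrained` (weights only) instead of the full Accelerate state. A minimal sketch of flipping the flag, assuming a config is loaded the way the trlx examples do (the YAML path is a placeholder):

from trlx.data.configs import TRLConfig

# Load a base config (path is illustrative) and disable optimizer-state checkpoints,
# so intermediate saves write only the Hugging Face model via save_pretrained().
config = TRLConfig.load_yaml("configs/ppo_config.yml")
config.train.save_optimizer = False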

trlx/models/modeling_base.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def save_pretrained(self, *args, **kwargs):
                 Keyword arguments passed along to the underlying model's
                 `save_pretrained` method.
         """
-        state_dict = kwargs.pop("state_dict", None)
+        state_dict = kwargs.get("state_dict", None)
         if state_dict is None:
             state_dict = self.state_dict()
         kwargs["state_dict"] = state_dict

trlx/trainer/accelerate_base_trainer.py

Lines changed: 27 additions & 19 deletions
@@ -277,8 +277,16 @@ def save_pretrained(self, directory: Optional[str] = None, **kwargs):
         """
         if directory is None:
             directory = os.path.join(self.config.train.checkpoint_dir, "hf_model")
+
         self.accelerator.wait_for_everyone()
-        self.accelerator.unwrap_model(self.model).save_pretrained(directory, **kwargs)
+        self.accelerator.unwrap_model(self.model).save_pretrained(
+            directory,
+            save_function=self.accelerator.save,
+            is_main_process=self.accelerator.is_main_process,
+            state_dict=self.accelerator.get_state_dict(self.model),
+            **kwargs,
+        )
+
         if self.accelerator.is_main_process:
             self.tokenizer.save_pretrained(directory)
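
This is the core of the fix: under DeepSpeed ZeRO-3 each rank holds only a shard of the parameters, so saving the local `state_dict()` would produce incomplete weights. `accelerator.get_state_dict(self.model)` gathers the full state dict, `is_main_process` restricts file writes to one rank, and `save_function=self.accelerator.save` defers the actual write to Accelerate. A standalone sketch of the same pattern outside the trainer (model name and output path are placeholders):

from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("gpt2"))

# ... training loop ...

accelerator.wait_for_everyone()
unwrapped = accelerator.unwrap_model(model)
unwrapped.save_pretrained(
    "out/hf_model",                                # placeholder directory
    is_main_process=accelerator.is_main_process,   # only one rank writes files
    save_function=accelerator.save,                # let Accelerate handle the actual save
    state_dict=accelerator.get_state_dict(model),  # gathers ZeRO-3 shards into a full state dict
)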

@@ -540,17 +548,24 @@ def learn(self):  # noqa: C901
                 self.scheduler.step()
                 self.iter_count += 1

-                if self.iter_count % self.config.train.checkpoint_interval == 0:
+                if (
+                    self.iter_count % self.config.train.checkpoint_interval == 0
+                    or self.iter_count >= self.total_steps
+                ):
                     subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}"
                     directory = os.path.join(self.config.train.checkpoint_dir, subfolder)
-                    self.save(directory)
+                    logger.info(f"Saving intermediate checkpoint into {directory}")
+                    if self.config.train.save_optimizer:
+                        self.save(directory)
+                    else:
+                        self.save_pretrained(directory)

                 stats["time/forward"] = forward_time
                 stats["time/backward"] = backward_time
                 for group_number, lr in enumerate(self.scheduler.get_last_lr()):
                     stats[f"learning_rate_group_{group_number}"] = lr

-                if self.iter_count % self.config.train.eval_interval == 0:
+                if self.iter_count % self.config.train.eval_interval == 0 or self.iter_count >= self.total_steps:
                     results = self.evaluate()
                     stats.update(results)
                     if ray.is_initialized():
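
A small detail in the hunk above: the checkpoint subfolder zero-pads the step count to the width of `total_steps`, so checkpoint directories sort lexicographically in step order. For example (values are illustrative):

iter_count, total_steps = 150, 10000
subfolder = f"checkpoint_{iter_count:0{len(str(total_steps))}d}"
print(subfolder)  # -> checkpoint_00150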
@@ -571,29 +586,22 @@ def learn(self):  # noqa: C901
                     if torch.distributed.is_initialized():
                         torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX)
                     if do_save:
-                        best_path = f"{self.config.train.checkpoint_dir}/best_checkpoint"
-                        logger.info(f"Saving the best state so far into {best_path}")
-                        self.save(best_path)
+                        directory = os.path.join(self.config.train.checkpoint_dir, "best_checkpoint")
+                        logger.info(f"Saving the best state so far into {directory}")
+                        if self.config.train.save_optimizer:
+                            self.save(directory)
+                        else:
+                            self.save_pretrained(directory)

                 desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
                 tbar.set_description(f"[{desc}]")
                 tbar.update()

-                if self.iter_count >= self.total_steps:
-                    subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}"
-                    directory = os.path.join(self.config.train.checkpoint_dir, subfolder)
-                    results = self.evaluate()
-                    stats.update(results)
-
-                    if ray.is_initialized():
-                        session.report(filter_non_scalars(stats), checkpoint=checkpoint)
-                    self.accelerator.log(stats, step=self.iter_count)
+                self.accelerator.log(stats, step=self.iter_count)

-                    self.save(directory)
+                if self.iter_count >= self.total_steps:
                     return results

-                self.accelerator.log(stats, step=self.iter_count)
-
             self.post_backward_callback()

         self.post_epoch_callback()
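
One more pattern visible in the context lines: the `do_save` flag is MAX-reduced across ranks before the branch, so every process agrees on whether to save and the collective calls inside `save`/`save_pretrained` cannot deadlock. A self-contained sketch of that idea (the helper name is made up; a process group may or may not be initialized):

import torch
import torch.distributed as dist

def sync_do_save(local_improved: bool) -> bool:
    """Return True on every rank if any rank observed an improvement."""
    # CPU tensor shown for simplicity; move it to the right device for NCCL backends.
    flag = torch.tensor(int(local_improved))
    if dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())

# Single-process example: no process group, so the local decision is used as-is.
print(sync_do_save(local_improved=False))  # -> False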
