Commit 62fc9c5

improve model saving for deepspeed

1 parent 6f81182

1 file changed

llm_studio/src/utils/modeling_utils.py

Lines changed: 20 additions & 15 deletions
@@ -98,24 +98,26 @@ def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any):
 
     if cfg.environment.use_deepspeed:
         if path is not None:
-            # gather model params from all ranks
-            # if hasattr(cfg.training, "lora") and cfg.training.lora:
-            #     model.backbone.save_pretrained(path)
-            model.save_checkpoint(os.path.join(path, "ds_checkpoint"))
-            if cfg.environment._local_rank == 0:
-                # load to cpu
-                state_dict = get_fp32_state_dict_from_zero_checkpoint(
-                    os.path.join(path, "ds_checkpoint")
+            # gather model params from all ranks when using Deepspeed
+            if not model.save_16bit_model(path, "checkpoint.pth"):
+                logger.warning(
+                    "deepspeed.save_16bit_model didn't save the model, since"
+                    " stage3_gather_16bit_weights_on_model_save=False."
+                    " Saving the full checkpoint instead"
                 )
-                # save as normal checkpoint that can be loaded by `load_state_dict`
-                checkpoint = {"model": state_dict}
-                torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
-                shutil.rmtree(os.path.join(path, "ds_checkpoint"))
+                model.save_checkpoint(os.path.join(path, "ds_checkpoint"))
+                if cfg.environment._local_rank == 0:
+                    # load to cpu
+                    state_dict = get_fp32_state_dict_from_zero_checkpoint(
+                        os.path.join(path, "ds_checkpoint")
+                    )
+                    # save as normal checkpoint that can be loaded by `load_state_dict`
+                    checkpoint = {"model": state_dict}
+                    torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
+                    shutil.rmtree(os.path.join(path, "ds_checkpoint"))
     else:
         if cfg.environment._local_rank == 0:
             model = unwrap_model(model)
-            # if hasattr(cfg.training, "lora") and cfg.training.lora:
-            #     model.backbone.save_pretrained(path)
             checkpoint = {"model": model.state_dict()}
             if path is not None:
                 torch.save(checkpoint, os.path.join(path, "checkpoint.pth"))
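
For reference, the fallback branch above leans on DeepSpeed's zero-to-fp32 utility, which reconstructs a full fp32 state dict on CPU from the sharded ZeRO checkpoint that every rank wrote. A minimal standalone sketch of the same consolidation, assuming illustrative paths rather than llm_studio defaults:

import os
import shutil

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# illustrative run directory, not an llm_studio default
path = "/tmp/run"
ds_dir = os.path.join(path, "ds_checkpoint")

# all ranks write the sharded ZeRO checkpoint (model.save_checkpoint in the
# diff); consolidation and cleanup then happen on rank 0 only
state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_dir)  # CPU fp32 weights
torch.save({"model": state_dict}, os.path.join(path, "checkpoint.pth"))
shutil.rmtree(ds_dir)  # the sharded checkpoint is no longer needed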
@@ -193,7 +195,9 @@ def load_checkpoint(
     if weights_path is None:
         weights_path = cfg.architecture.pretrained_weights
 
-    model_weights = torch.load(weights_path, map_location="cpu")["model"]
+    model_weights = torch.load(weights_path, map_location="cpu")
+    if "model" in model_weights.keys():
+        model_weights = model_weights["model"]
 
     model = load_model_weights(model, model_weights, strict, cfg)
 
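The load-side change is needed because the two save paths now produce different on-disk layouts: save_16bit_model writes the bare state dict to checkpoint.pth, while the fallback wraps the consolidated fp32 weights as {"model": state_dict}. A hedged sketch of a standalone loader that accepts both layouts (the helper name is hypothetical; the file keys follow the diff):

import torch

def load_weights(model: torch.nn.Module, weights_path: str, strict: bool = True):
    # hypothetical helper mirroring the load_checkpoint change above
    model_weights = torch.load(weights_path, map_location="cpu")
    # save_16bit_model stores a bare state dict; the ZeRO fallback wraps
    # it as {"model": state_dict} - accept both layouts
    if "model" in model_weights:
        model_weights = model_weights["model"]
    model.load_state_dict(model_weights, strict=strict)
    return model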
@@ -224,6 +228,7 @@ def get_ds_config(cfg: Any):
         # zero3
         "stage3_prefetch_bucket_size": cfg.environment.deepspeed_stage3_prefetch_bucket_size,  # noqa: E501
         "stage3_param_persistence_threshold": cfg.environment.deepspeed_stage3_param_persistence_threshold,  # noqa: E501
+        "stage3_gather_16bit_weights_on_model_save": True,
         # zero3 offload cpu
         # "stage3_max_live_parameters": cfg.environment.deepspeed_stage3_max_live_parameters,  # noqa: E501
         # "stage3_max_reuse_distance": cfg.environment.deepspeed_stage3_max_reuse_distance,  # noqa: E501
