
Commit 6e101b2

kylesayrs and dsikka authored
[Callbacks] Consolidate Saving Methods (#1168)
## Purpose ##

* Simplify all methods of saving into one point, namely the wrapped `save_pretrained` function
* Precursor to #1160
* Needed to provide a single point for saving on top of existing recipes

## Background ##

Everything that needs to happen during saving:

1. Save the model weights, potentially compressed
2. Save the processor
3. Update the recipe checkpoint
4. Copy any necessary python files from the model cache
5. Only save on the main process

After these changes, (1, 2, 3, 4) will be done within the `save_pretrained` function, and (5) will be the responsibility of the caller. (3) will be implemented by #1160 so as not to conflict with existing logic in pre_init.

All of the places where a model is saved:

* At the end of the main function, if an output dir is specified
* Between stages of the stage runner
* Between epochs of the HF Trainer
* By the user after oneshot/training completes

After these changes, all of these will be replaced by a single `save_checkpoint` function, which calls `save_pretrained` to do all the necessary work (a caller-side sketch follows below).

## Changes ##

* Remove `save_model_and_recipe`
  * Saving recipes is now done by the `save_pretrained` function
* Implement `save_checkpoint`
  * Single entrypoint for saving a model and its processor
  * Performs actions (1, 2, 4)
* Replace all locations where a model is saved with `save_checkpoint`
  * All applicable callers now save only on the main process (5)
* Remove support for `modify_fsdp_model_save_pretrained` and `unwrap_and_export_model`, to be added back in a future release

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent d810e4a commit 6e101b2
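
Editor's note: below is a minimal caller-side sketch of the pattern described above, using the consolidated `save_checkpoint` helper introduced in this commit. The wrapper name `save_on_main_process` and the `accelerator`, `model`, `processor`, and `output_dir` arguments are illustrative, not part of the diff. Responsibilities (1), (2), and (4) happen inside the wrapped `save_pretrained` call; (5), saving only on the main process, stays with the caller, followed by a barrier so no rank races ahead of the write.

```python
from accelerate import Accelerator
from transformers import PreTrainedModel

from llmcompressor.pytorch.model_load.helpers import save_checkpoint
from llmcompressor.typing import Processor


def save_on_main_process(
    accelerator: Accelerator,
    model: PreTrainedModel,
    processor: Processor,
    output_dir: str,
) -> None:
    """Illustrative wrapper: save only on rank 0, then synchronize all ranks."""
    if accelerator.is_main_process:  # (5) is the caller's responsibility
        save_checkpoint(
            save_path=output_dir,
            model=model,
            processor=processor,
            save_safetensors=True,
            save_compressed=True,
        )
    # barrier so no process continues before the checkpoint is on disk
    accelerator.wait_for_everyone()
```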

6 files changed (+64, -194 lines)

src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 11 additions & 32 deletions
@@ -6,6 +6,7 @@
 from loguru import logger
 from safetensors import safe_open
 from torch.nn import Module
+from transformers import PreTrainedModel
 
 from llmcompressor.core import active_session, create_session, pre_initialize_structure
 from llmcompressor.typing import Processor
@@ -14,20 +15,19 @@
 
 __all__ = [
     "initialize_recipe",
-    "save_model_and_recipe",
     "copy_python_files_from_model_cache",
     "fallback_to_cpu",
     "parse_dtype",
     "get_session_model",
     "get_completed_stages",
     "save_completed_stages",
+    "save_checkpoint",
 ]
 
 
 def initialize_recipe(model: Module, recipe_path: str):
     """
     Initializes a recipe that has been previously applied to the model
-
     :param model: PyTorch model to apply structure to
     :param recipe_path: path to recipe to apply to the model
     """
@@ -49,43 +49,22 @@ def initialize_recipe(model: Module, recipe_path: str):
     logger.info(f"Applied {msg} to the model")
 
 
-def save_model_and_recipe(
-    model: Module,
+def save_checkpoint(
     save_path: str,
-    processor: Optional[Processor] = None,
-    save_safetensors: bool = False,
-    save_compressed: bool = False,
+    model: PreTrainedModel,
+    processor: Processor,
+    save_safetensors: bool = True,
+    save_compressed: bool = True,
 ):
-    """
-    Save a model, processor and the currently loaded recipe to file
-
-    :param model: pytorch model to save
-    :param save_path: path to save output to
-    :param processor: model processor or tokenizer to save
-    :param save_safetensors: whether to save as safetensors or pickle (bin)
-    :param save_compressed: whether to compress sparse weights on disk
-    """
-    # avoid circular import
-    from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME
-
+    # saving the model also saves the recipe
     model.save_pretrained(
-        save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
+        save_path,
+        save_safetensors=save_safetensors,
+        save_compressed=save_compressed,
     )
-
     if processor is not None:
         processor.save_pretrained(save_path)
 
-    logger.info("Saving output to {}".format(os.path.abspath(save_path)))
-
-    recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
-    session = active_session()
-    recipe_yaml_str = session.get_serialized_recipe()
-    with open(recipe_path, "w") as fp:
-        fp.write(recipe_yaml_str)
-
-    # copy python files from cache dir to save_path if any
-    copy_python_files_from_model_cache(model, save_path)
-
 
 def fallback_to_cpu(device: str) -> str:
     """

src/llmcompressor/transformers/finetune/runner.py

Lines changed: 10 additions & 4 deletions
@@ -17,6 +17,7 @@
 from llmcompressor.pytorch.model_load.helpers import (
     get_completed_stages,
     get_session_model,
+    save_checkpoint,
     save_completed_stages,
 )
 from llmcompressor.recipe import Recipe, StageRunType
@@ -26,7 +27,6 @@
     make_dataset_splits,
 )
 from llmcompressor.typing import Processor
-from llmcompressor.utils.fsdp.helpers import save_model_and_recipe
 
 
 class StageRunner:
@@ -231,14 +231,20 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
 
             checkpoint = None
 
-            if self._training_args.output_dir:
-                save_model_and_recipe(
-                    model=self.trainer.model,
+            # save model between stages
+            if (
+                self._training_args.output_dir
+                != TrainingArguments.__dataclass_fields__["output_dir"].default
+                and self.trainer.accelerator.is_main_process
+            ):
+                save_checkpoint(
                     save_path=self._output_dir,
+                    model=self.trainer.model,
                     processor=self.processor,
                     save_safetensors=self._training_args.save_safetensors,
                     save_compressed=self._model_args.save_compressed,
                 )
+            self.trainer.accelerator.wait_for_everyone()
 
             # save stage to checkpoint dir
             if self.trainer.accelerator.is_main_process:
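
Editor's note: the condition in the hunk above saves between stages only when the user explicitly chose an output directory, comparing `output_dir` against its declared dataclass default rather than testing truthiness. A minimal sketch of that check, assuming a dataclass-based training-arguments object; the helper name `output_dir_was_set` is hypothetical:

```python
def output_dir_was_set(training_args) -> bool:
    # Compare against the field's declared default rather than a hard-coded
    # value, so the check keeps working if the default ever changes.
    default = type(training_args).__dataclass_fields__["output_dir"].default
    return training_args.output_dir != default
```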

src/llmcompressor/transformers/finetune/session_mixin.py

Lines changed: 10 additions & 38 deletions
@@ -23,15 +23,13 @@
 from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
     KDModelWrapper,
 )
-from llmcompressor.pytorch.model_load.helpers import get_session_model
+from llmcompressor.pytorch.model_load.helpers import get_session_model, save_checkpoint
 from llmcompressor.pytorch.utils import ModuleSparsificationInfo
-from llmcompressor.transformers import RECIPE_FILE_NAME
 from llmcompressor.transformers.finetune.callbacks import (
     DisableHalfPrecisionCallback,
     TrainingLoopCallbacks,
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
-from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
 from llmcompressor.utils.pytorch import qat_active
 
 if TYPE_CHECKING:
@@ -64,8 +62,8 @@ class SessionManagerMixIn:
     def __init__(
         self,
         recipe: str,
+        data_args: "DatasetArguments",
         model_args: "ModelArguments",
-        data_args: Optional["DatasetArguments"] = None,
         teacher: Optional[Union[Module, str]] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
         **kwargs,
@@ -183,7 +181,6 @@ def initialize_structure(self, stage: Optional[str] = None):
         """
         Initialize any recipe structural changes such as quantization on the model,
         return immediately if session has already been initialized
-
         :param stage: Optional stage of recipe to run, or None to run all stages
         """
         session = active_session()
@@ -399,44 +396,19 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
 
         # knowledge distillation requires making wrappers transparent during
         if isinstance(self.model, KDModelWrapper):
-            self.model.prepare_for_save()
+            self.model.prepare_for_save()  # TODO: move to finalize
 
-        if not is_fsdp_model(self.model):
-            self.model.save_pretrained(
+        # save checkpoint
+        self.save_state()
+        if self.accelerator.is_main_process:
+            processor = getattr(self, "processing_class", self.tokenizer)
+            save_checkpoint(
                 output_dir,
-                save_compressed=self.model_args.save_compressed,
-                safe_serialization=self.args.save_safetensors,
-            )
-        else:  # FSDP model
-            save_pretrained_fsdp(
                 model=self.model,
-                accelerator=self.accelerator,
-                output_dir=output_dir,
+                processor=processor,
+                save_safetensors=self.args.save_safetensors,
                 save_compressed=self.model_args.save_compressed,
-                save_safetensors=self.metadata.get("save_safetensors", False),
-            )
-
-        self.save_state()
-        processor = getattr(self, "processing_class", self.tokenizer)
-        if processor is not None:
-            processor.save_pretrained(output_dir)
-
-        if not self.recipe:
-            return
-
-        if self.accelerator.is_main_process:
-            # save recipe, will contain modifiers from the model's original recipe as
-            # well as those added from self.recipe
-            recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
-            session = active_session()
-            recipe_yaml_str = session.get_serialized_recipe()
-            with open(recipe_path, "w") as fp:
-                fp.write(recipe_yaml_str)
-
-            logger.info(
-                f"Saved LLM Compressor recipe with model state to {recipe_path}"
             )
-
         self.accelerator.wait_for_everyone()
 
         if isinstance(self.model, KDModelWrapper):
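
Editor's note: the processor lookup in `save_model` above uses `getattr(self, "processing_class", self.tokenizer)` so the mixin works whether or not the underlying `Trainer` exposes a `processing_class` attribute, falling back to the older `tokenizer` attribute otherwise. A standalone sketch of that lookup; the `resolve_processor` name is illustrative:

```python
from transformers import Trainer


def resolve_processor(trainer: Trainer):
    # Prefer `processing_class` when the Trainer defines it; otherwise fall
    # back to the legacy `tokenizer` attribute.
    return getattr(trainer, "processing_class", trainer.tokenizer)
```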

src/llmcompressor/transformers/finetune/text_generation.py

Lines changed: 13 additions & 7 deletions
@@ -46,12 +46,12 @@
     get_session_model,
     initialize_recipe,
     parse_dtype,
+    save_checkpoint,
 )
 from llmcompressor.recipe import Recipe, StageRunType
 from llmcompressor.transformers.finetune.runner import StageRunner
 from llmcompressor.transformers.finetune.trainer import Trainer
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
-    modify_fsdp_model_save_pretrained,
     modify_save_pretrained,
     patch_tied_tensors_bug,
 )
@@ -415,7 +415,10 @@ def main(
 
     # wrap model.save_pretrained
     if is_fsdp_model(model):
-        modify_fsdp_model_save_pretrained(trainer, processor)
+        raise NotImplementedError(
+            "FSDP models are not supported in the current release but will be "
+            "suported in future releases of LLM Compressor"
+        )
     else:
         modify_save_pretrained(model)
 
@@ -440,16 +443,19 @@ def main(
     stage_runner.train(checkpoint)
 
     # save if model was provided as a string or custom output_dir was set
-
     if isinstance(model_args.model, str) or (
         training_args.output_dir
         != TrainingArguments.__dataclass_fields__["output_dir"].default
+        and trainer.accelerator.is_main_process
    ):
-        model.save_pretrained(
-            training_args.output_dir, save_compressed=model_args.save_compressed
+        save_checkpoint(
+            save_path=training_args.output_dir,
+            model=model,
+            processor=processor,
+            save_safetensors=True,
+            save_compressed=model_args.save_compressed,
         )
-        if processor is not None:
-            processor.save_pretrained(training_args.output_dir)
+        trainer.accelerator.wait_for_everyone()
 
     # Clean up the CompressionSession before exit if requested
     if recipe_args.clear_sparse_session: