
Commit 15ee0ac

Switch to the Composer integration of LoRA (works with FSDP) (#886)
1 parent d9874d2 commit 15ee0ac
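
For context on the title: Composer's HuggingFaceModel applies the LoRA adapters itself when it is given a peft_config, and exposes a using_peft flag that the checkpointer changes below branch on. A minimal sketch of that usage, assuming Composer's peft_config argument and an illustrative base model and LoRA settings (none of this is taken from the diff):

# Minimal sketch of the Composer-side LoRA integration this commit switches to.
# 'gpt2' and the LoRA hyperparameters are illustrative assumptions.
from composer.models import HuggingFaceModel
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
hf_model = AutoModelForCausalLM.from_pretrained('gpt2')

# Composer wraps the HF model, injects the adapter, and sets `using_peft`,
# which the checkpointer callback below uses to pick the save path.
composer_model = HuggingFaceModel(
    hf_model,
    tokenizer=tokenizer,
    peft_config=LoraConfig(r=8, lora_alpha=16, target_modules=['c_attn']),
)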

File tree

10 files changed, +624 -356 lines changed


llmfoundry/callbacks/hf_checkpointer.py

Lines changed: 41 additions & 11 deletions
@@ -9,7 +9,7 @@
 import re
 import tempfile
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import torch
 from composer.core import Callback, Event, State, Time, TimeUnit
@@ -203,14 +203,17 @@ def _save_checkpoint(self, state: State, logger: Logger):
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
         if state.is_model_ddp:
+            composer_model = state.model.module
             original_model: PreTrainedModel = state.model.module.model
             state_dict_model = state.model.module.model
             original_tokenizer = state.model.module.tokenizer
         elif isinstance(state.model.model, FSDP):
+            composer_model = state.model
             original_model: PreTrainedModel = state.model.model.module
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
         else:
+            composer_model = state.model
             original_model: PreTrainedModel = state.model.model
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
@@ -237,10 +240,25 @@ def _save_checkpoint(self, state: State, logger: Logger):
         copied_config.init_device = 'cpu'
 
         log.debug(f'Creating new model instance')
-        # First create the model instance on meta device to avoid the
-        # initialization cost.
-        with init_empty_weights():
-            new_model_instance = type(original_model)(copied_config)
+
+        if composer_model.using_peft:
+            # We don't use meta here because the state dict does not contain the full
+            # model, only the adapter weights.
+            active_adapter = original_model.active_adapter
+            base_model = original_model.get_base_model()
+            new_base_model_instance = type(base_model)(copied_config)
+
+            new_model_instance = type(original_model)(
+                new_base_model_instance,
+                original_model.peft_config[active_adapter])
+        else:
+            # First create the model instance on meta device to avoid the
+            # initialization cost.
+            with init_empty_weights():
+                new_model_instance = type(original_model)(copied_config)
+
+        new_model_instance.to(dtype=self.dtype)
+        new_model_instance.load_state_dict(state_dict)
 
         # Then load the state dict in with "assign" so that the state dict
         # is loaded properly even though the model is initially on meta device.
@@ -295,12 +313,24 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     # TODO: Remove after mlflow fixes the bug that makes this necessary
                     import mlflow
                     mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-                    mlflow_logger.save_model(
-                        flavor='transformers',
-                        transformers_model=components,
-                        path=local_save_path,
-                        **self.mlflow_logging_config,
-                    )
+                    model_saving_kwargs: Dict[str, Any] = {
+                        'path': local_save_path
+                    }
+                    if composer_model.using_peft:
+                        model_saving_kwargs['flavor'] = 'peft'
+                        model_saving_kwargs[
+                            'save_pretrained_dir'] = temp_save_dir
+                        model_saving_kwargs[
+                            'metadata'] = self.mlflow_logging_config[
+                                'metadata']
+                    else:
+                        model_saving_kwargs['flavor'] = 'transformers'
+                        model_saving_kwargs[
+                            'transformers_model'] = components
+                        model_saving_kwargs.update(
+                            self.mlflow_logging_config)
+
+                    mlflow_logger.save_model(**model_saving_kwargs)
 
                     license_filename = _maybe_get_license_filename(
                         local_save_path)
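
The PEFT branch above cannot build the new model on the meta device because the gathered state dict holds only the adapter weights, so a real base model is instantiated from the copied config and then wrapped with the stored peft_config. A rough standalone sketch of that pattern (the model name, adapter settings, and adapter_state_dict below are illustrative, not from the diff):

# Rough standalone sketch of the PEFT re-instantiation pattern in the diff.
# 'gpt2' and the LoRA settings are illustrative; the callback derives the
# classes and config from the training state instead.
from peft import LoraConfig, PeftModel
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')
base_model = AutoModelForCausalLM.from_config(config)  # real (non-meta) parameters

lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=['c_attn'])
new_model = PeftModel(base_model, lora_config)

# Only adapter weights live in the checkpoint, so the base weights must already
# be materialized before loading; a hypothetical adapter-only state dict:
# new_model.load_state_dict(adapter_state_dict, strict=False)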
