 import re
 import tempfile
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union

 import torch
 from composer.core import Callback, Event, State, Time, TimeUnit
@@ -203,14 +203,17 @@ def _save_checkpoint(self, state: State, logger: Logger):
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

         if state.is_model_ddp:
+            composer_model = state.model.module
             original_model: PreTrainedModel = state.model.module.model
             state_dict_model = state.model.module.model
             original_tokenizer = state.model.module.tokenizer
         elif isinstance(state.model.model, FSDP):
+            composer_model = state.model
             original_model: PreTrainedModel = state.model.model.module
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
         else:
+            composer_model = state.model
             original_model: PreTrainedModel = state.model.model
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
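The three new `composer_model` assignments above exist so that later code can ask the Composer wrapper (rather than the raw Hugging Face model) whether PEFT adapters are in use via `composer_model.using_peft`. As a rough illustration of what such a flag reflects, one way to detect adapters, assuming the `peft` package is available (the `is_peft_model` helper below is illustrative and not part of this change):

def is_peft_model(model) -> bool:
    # Illustrative only: a model trained with adapters is wrapped in a
    # peft.PeftModel, which is presumably what `using_peft` records.
    try:
        from peft import PeftModel
    except ImportError:
        return False
    return isinstance(model, PeftModel)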
@@ -237,10 +240,25 @@ def _save_checkpoint(self, state: State, logger: Logger):
             copied_config.init_device = 'cpu'

             log.debug(f'Creating new model instance')
-            # First create the model instance on meta device to avoid the
-            # initialization cost.
-            with init_empty_weights():
-                new_model_instance = type(original_model)(copied_config)
+
+            if composer_model.using_peft:
+                # We don't use meta here because the state dict does not contain the full
+                # model, only the adapter weights.
+                active_adapter = original_model.active_adapter
+                base_model = original_model.get_base_model()
+                new_base_model_instance = type(base_model)(copied_config)
+
+                new_model_instance = type(original_model)(
+                    new_base_model_instance,
+                    original_model.peft_config[active_adapter])
+            else:
+                # First create the model instance on meta device to avoid the
+                # initialization cost.
+                with init_empty_weights():
+                    new_model_instance = type(original_model)(copied_config)
+
+                new_model_instance.to(dtype=self.dtype)
+                new_model_instance.load_state_dict(state_dict)

             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
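For the non-PEFT branch, the surrounding comments describe the meta-device pattern: instantiate the architecture with empty weights, then materialize real tensors from the state dict. A minimal standalone sketch of that pattern, assuming torch >= 2.1 (for `assign=True`) plus the `accelerate` and `transformers` packages; 'gpt2' is only a placeholder checkpoint and not part of this change:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')
with init_empty_weights():
    # Parameters live on the meta device: no memory allocation, no init cost.
    empty_model = AutoModelForCausalLM.from_config(config)

real_state_dict = AutoModelForCausalLM.from_pretrained('gpt2').state_dict()
# assign=True swaps the meta tensors for the real ones from the state dict
# instead of copying into them (meta tensors have no storage to copy into).
empty_model.load_state_dict(real_state_dict, assign=True)

The PEFT branch cannot use this trick because the gathered state dict holds only the adapter weights, so the base model is rebuilt with real CPU parameters first and the adapter is wrapped around it.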
@@ -295,12 +313,24 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     # TODO: Remove after mlflow fixes the bug that makes this necessary
                     import mlflow
                     mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-                    mlflow_logger.save_model(
-                        flavor='transformers',
-                        transformers_model=components,
-                        path=local_save_path,
-                        **self.mlflow_logging_config,
-                    )
+                    model_saving_kwargs: Dict[str, Any] = {
+                        'path': local_save_path
+                    }
+                    if composer_model.using_peft:
+                        model_saving_kwargs['flavor'] = 'peft'
+                        model_saving_kwargs[
+                            'save_pretrained_dir'] = temp_save_dir
+                        model_saving_kwargs[
+                            'metadata'] = self.mlflow_logging_config[
+                                'metadata']
+                    else:
+                        model_saving_kwargs['flavor'] = 'transformers'
+                        model_saving_kwargs[
+                            'transformers_model'] = components
+                        model_saving_kwargs.update(
+                            self.mlflow_logging_config)
+
+                    mlflow_logger.save_model(**model_saving_kwargs)

                     license_filename = _maybe_get_license_filename(
                         local_save_path)
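The refactor above replaces one hard-coded call with a kwargs dict so that the MLflow flavor and its flavor-specific arguments ('peft' vs. 'transformers') can be chosen at runtime and passed through a single `save_model(**model_saving_kwargs)` call. The same build-then-unpack pattern against the plain MLflow API, as a sketch assuming a transformers pipeline and a placeholder output path (neither appears in the diff):

from typing import Any, Dict

import mlflow
import transformers

pipe = transformers.pipeline('text-generation', model='gpt2')

# Collect flavor-specific arguments first, then dispatch once.
save_kwargs: Dict[str, Any] = {'path': '/tmp/example_mlflow_model'}
save_kwargs['transformers_model'] = pipe
save_kwargs['task'] = 'text-generation'

mlflow.transformers.save_model(**save_kwargs)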