|
13 | 13 | # limitations under the License.
14 | 14 | from __future__ import annotations
15 | 15 |
|
| 16 | +import os
| 17 | +
16 | 18 | import paddle
17 | 19 | from paddle import nn
18 | 20 | from paddle.distributed import fleet
|
26 | 28 | from paddlenlp.experimental.transformers.generation_utils import (
27 | 29 |     GenerationInferenceModel,
28 | 30 | )
| 31 | +from paddlenlp.experimental.transformers.utils import load_tp_checkpoint
29 | 32 | from paddlenlp.transformers import GPTConfig, GPTPretrainedModel
30 | 33 | from paddlenlp.transformers.gpt.modeling import GPTEmbeddings, parallel_matmul
31 | 34 | from paddlenlp.transformers.model_outputs import (
32 | 35 |     BaseModelOutputWithPastAndCrossAttentions,
33 | 36 |     CausalLMOutputWithCrossAttentions,
34 | 37 | )
35 | 38 | from paddlenlp.transformers.model_utils import (
| 39 | +    dtype_guard,
36 | 40 |     dy2st_nocheck_guard_context,
| 41 | +    no_init_weights,
37 | 42 |     register_base_model,
38 | 43 | )
| 44 | +from paddlenlp.transformers.utils import (
| 45 | +    ContextManagers,
| 46 | +    is_paddle_support_lazy_init,
| 47 | +    is_safetensors_available,
| 48 | +)
39 | 49 |
|
40 | 50 | __all__ = ["GPTInferenceModel", "GPTForCausalLMInferenceModel"]
41 | 51 |
|
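The new `model_utils` and `transformers.utils` imports feed the initialization contexts built in `from_pretrained` below: `no_init_weights` skips random parameter initialization, `dtype_guard` pins the default dtype while layers are constructed, and `ContextManagers` stacks the guards. As a conceptual sketch of the dtype guard, using only public Paddle APIs (an illustration of the idea, not the library source):

```python
from contextlib import contextmanager

import paddle


@contextmanager
def dtype_guard_sketch(dtype: str):
    # Conceptual stand-in for paddlenlp's dtype_guard: parameters created
    # inside the context use `dtype` (e.g. "float16"); the previous default
    # dtype is restored on exit, even if construction raises.
    original = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(original)
```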
@@ -446,9 +456,47 @@ def __init__(self, config):
446 | 456 |
|
447 | 457 |     @classmethod
448 | 458 |     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
449 | | -        # TODO: Support safetensors loading.
450 | | -        kwargs["use_safetensors"] = False
451 | | -        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
| 459 | +        config = kwargs.pop("config", None)
| 460 | +        cache_dir = kwargs.pop("cache_dir", None)
| 461 | +        dtype = kwargs.pop("dtype", None)
| 462 | +        if dtype is None:
| 463 | +            dtype = config.dtype
| 464 | +        subfolder = kwargs.pop("subfolder", None)
| 465 | +        if subfolder is None:
| 466 | +            subfolder = ""
| 467 | +        variant = kwargs.pop("variant", None)
| 468 | +        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
| 469 | +        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
| 470 | +
| 471 | +        init_contexts = []
| 472 | +        if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
| 473 | +            # Skip random weight initialization; weights are loaded from the checkpoint below.
| 474 | +            init_contexts.append(no_init_weights(_enable=True))
| 475 | +            if is_paddle_support_lazy_init():
| 476 | +                init_contexts.append(paddle.LazyGuard())
| 477 | +        if dtype:
| 478 | +            init_contexts.append(dtype_guard(dtype))
| 479 | +
| 480 | +        # Instantiate the model under the combined init contexts.
| 481 | +        with ContextManagers(init_contexts):
| 482 | +            model = cls(config)
| 483 | +
| 484 | +        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
| 485 | +            pretrained_model_name_or_path,
| 486 | +            cache_dir=cache_dir,
| 487 | +            subfolder=subfolder,
| 488 | +            from_hf_hub=False,
| 489 | +            from_aistudio=False,
| 490 | +            config=config,
| 491 | +            convert_from_torch=False,
| 492 | +            use_safetensors=use_safetensors,
| 493 | +            variant=variant,
| 494 | +        )
| 495 | +
| 496 | +        model_path = os.path.dirname(resolved_archive_file)
| 497 | +        state_dict = load_tp_checkpoint(model_path, cls, config)
| 498 | +        model.set_state_dict(state_dict)
| 499 | +        return model
452 | 500 |
|
453 | 501 |     @classmethod
454 | 502 |     def get_cache_kvs_shape(
|
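Taken together, the overridden `from_pretrained` now builds the model under the init contexts, resolves the checkpoint files itself, and loads a tensor-parallel state dict via `load_tp_checkpoint` instead of delegating to the base implementation. A minimal usage sketch under assumed settings (the checkpoint name `gpt2-en`, the 2-way tensor-parallel setup, and the dtype are illustrative; only the `config=` contract comes from this diff):

```python
from paddle.distributed import fleet

from paddlenlp.experimental.transformers import GPTForCausalLMInferenceModel
from paddlenlp.transformers import GPTConfig

# Assumed 2-way tensor-parallel launch, e.g. via
#   python -m paddle.distributed.launch --gpus "0,1" this_script.py
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

config = GPTConfig.from_pretrained("gpt2-en")  # illustrative checkpoint name
config.tensor_parallel_degree = 2
config.tensor_parallel_rank = fleet.worker_index()

# `config` is required here: the new from_pretrained pops it from kwargs and
# reads config.dtype / config.quantization_config before instantiating the
# model; each rank then loads its own weight shard via load_tp_checkpoint.
model = GPTForCausalLMInferenceModel.from_pretrained(
    "gpt2-en",
    config=config,
    dtype="float16",
)
```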