Commit cf7b2fb

update opt, gpt predict
1 parent 33fde67 commit cf7b2fb

File tree: 2 files changed, +102 additions, -6 deletions

paddlenlp/experimental/transformers/gpt/modeling.py

Lines changed: 51 additions & 3 deletions
@@ -13,6 +13,8 @@
 # limitations under the License.
 from __future__ import annotations

+import os
+
 import paddle
 from paddle import nn
 from paddle.distributed import fleet
@@ -26,16 +28,24 @@
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
 )
+from paddlenlp.experimental.transformers.utils import load_tp_checkpoint
 from paddlenlp.transformers import GPTConfig, GPTPretrainedModel
 from paddlenlp.transformers.gpt.modeling import GPTEmbeddings, parallel_matmul
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
 )
 from paddlenlp.transformers.model_utils import (
+    dtype_guard,
     dy2st_nocheck_guard_context,
+    no_init_weights,
     register_base_model,
 )
+from paddlenlp.transformers.utils import (
+    ContextManagers,
+    is_paddle_support_lazy_init,
+    is_safetensors_available,
+)

 __all__ = ["GPTInferenceModel", "GPTForCausalLMInferenceModel"]

@@ -446,9 +456,47 @@ def __init__(self, config):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        # TODO: Support safetensors loading.
-        kwargs["use_safetensors"] = False
-        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        dtype = kwargs.pop("dtype", None)
+        if dtype is None:
+            dtype = config.dtype
+        subfolder = kwargs.pop("subfolder", None)
+        if subfolder is None:
+            subfolder = ""
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+
+        init_contexts = []
+        if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+            # Instantiate model.
+            init_contexts.append(no_init_weights(_enable=True))
+            if is_paddle_support_lazy_init():
+                init_contexts.append(paddle.LazyGuard())
+        if dtype:
+            init_contexts.append(dtype_guard(dtype))
+
+        # init the model
+        with ContextManagers(init_contexts):
+            model = cls(config)
+
+        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            from_hf_hub=False,
+            from_aistudio=False,
+            config=config,
+            convert_from_torch=False,
+            use_safetensors=use_safetensors,
+            variant=variant,
+        )
+
+        model_path = os.path.dirname(resolved_archive_file)
+        state_dict = load_tp_checkpoint(model_path, cls, config)
+        model.set_state_dict(state_dict)
+        return model

     @classmethod
     def get_cache_kvs_shape(
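
For context, a minimal usage sketch of the new loading path. This sketch is an assumption rather than part of the commit: the checkpoint directory and dtype are hypothetical placeholders, and the keyword arguments follow only what the diff above shows. Note that the new code reads config.dtype and config.quantization_config without a None check, so passing a config object appears to be required:

# Hypothetical usage; "./my-gpt-checkpoint" and the dtype are placeholders.
from paddlenlp.transformers import GPTConfig
from paddlenlp.experimental.transformers.gpt.modeling import (
    GPTForCausalLMInferenceModel,
)

config = GPTConfig.from_pretrained("./my-gpt-checkpoint")  # assumed local dir
model = GPTForCausalLMInferenceModel.from_pretrained(
    "./my-gpt-checkpoint",
    config=config,           # required in practice: config.dtype is the fallback
    dtype="float16",         # optional override; defaults to config.dtype
    low_cpu_mem_usage=True,  # opts into the no_init_weights / LazyGuard path
)

Also worth noting: resolved_sharded_files, sharded_metadata, and is_sharded are unpacked from _resolve_model_file_path but never used; only the directory of resolved_archive_file is handed to load_tp_checkpoint.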

paddlenlp/experimental/transformers/opt/modeling.py

Lines changed: 51 additions & 3 deletions
@@ -15,6 +15,8 @@
 # limitations under the License.
 from __future__ import annotations

+import os
+
 import numpy as np
 import paddle
 import paddle.nn as nn
@@ -26,13 +28,21 @@
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
 )
+from paddlenlp.experimental.transformers.utils import load_tp_checkpoint
 from paddlenlp.transformers import OPTPretrainedModel
 from paddlenlp.transformers.model_utils import (
+    dtype_guard,
     dy2st_nocheck_guard_context,
+    no_init_weights,
     register_base_model,
 )
 from paddlenlp.transformers.opt.configuration import OPTConfig
 from paddlenlp.transformers.opt.modeling import OPTEmbeddings, OPTLMHead
+from paddlenlp.transformers.utils import (
+    ContextManagers,
+    is_paddle_support_lazy_init,
+    is_safetensors_available,
+)

 __all__ = ["OPTForCausalLMInferenceModel", "OPTForBlip2InferenceModel"]

@@ -329,9 +339,47 @@ def __init__(self, config: OPTConfig, **kwargs):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        # TODO: Support safetensors loading.
-        kwargs["use_safetensors"] = kwargs.get("use_safetensors", False)
-        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        dtype = kwargs.pop("dtype", None)
+        if dtype is None:
+            dtype = config.dtype
+        subfolder = kwargs.pop("subfolder", None)
+        if subfolder is None:
+            subfolder = ""
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+
+        init_contexts = []
+        if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+            # Instantiate model.
+            init_contexts.append(no_init_weights(_enable=True))
+            if is_paddle_support_lazy_init():
+                init_contexts.append(paddle.LazyGuard())
+        if dtype:
+            init_contexts.append(dtype_guard(dtype))
+
+        # init the model
+        with ContextManagers(init_contexts):
+            model = cls(config)
+
+        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            from_hf_hub=False,
+            from_aistudio=False,
+            config=config,
+            convert_from_torch=False,
+            use_safetensors=use_safetensors,
+            variant=variant,
+        )
+
+        model_path = os.path.dirname(resolved_archive_file)
+        state_dict = load_tp_checkpoint(model_path, cls, config)
+        model.set_state_dict(state_dict)
+        return model

     @classmethod
     def get_cache_kvs_shape(
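
Both files use the same init_contexts pattern: collect optional guards (skip random weight initialization, create parameters lazily, pin a default dtype) and instantiate the model with all of them active at once. Below is a standard-library sketch of what ContextManagers(init_contexts) accomplishes, with contextlib.ExitStack standing in for paddlenlp's helper, whose internals this diff does not show:

import contextlib

def create_under_contexts(make_model, init_contexts):
    # Enter a variable-length list of context managers, then build the model
    # while all of them are active: the same shape as
    # `with ContextManagers(init_contexts): model = cls(config)` in the diff.
    with contextlib.ExitStack() as stack:
        for ctx in init_contexts:
            stack.enter_context(ctx)
        return make_model()

Skipping initialization (no_init_weights) and lazy parameter creation (paddle.LazyGuard) are worthwhile here because every parameter is immediately overwritten by model.set_state_dict(state_dict) once load_tp_checkpoint returns.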
