Skip to content

Commit 0d1026f

Browse files
committed
update return_numpy=True
1 parent cf7b2fb commit 0d1026f

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

paddlenlp/experimental/transformers/chatglm/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
631631
)
632632

633633
model_path = os.path.dirname(resolved_archive_file)
634-
state_dict = load_tp_checkpoint(model_path, cls, config)
634+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
635635
model.set_state_dict(state_dict)
636636
return model
637637

paddlenlp/experimental/transformers/gpt/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
494494
)
495495

496496
model_path = os.path.dirname(resolved_archive_file)
497-
state_dict = load_tp_checkpoint(model_path, cls, config)
497+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
498498
model.set_state_dict(state_dict)
499499
return model
500500

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
11851185
)
11861186

11871187
model_path = os.path.dirname(resolved_archive_file)
1188-
state_dict = load_tp_checkpoint(model_path, cls, config)
1188+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
11891189
model.set_state_dict(state_dict)
11901190
return model
11911191

@@ -1322,7 +1322,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
13221322
)
13231323

13241324
model_path = os.path.dirname(resolved_archive_file)
1325-
state_dict = load_tp_checkpoint(model_path, cls, config)
1325+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
13261326
model.set_state_dict(state_dict)
13271327
return model
13281328

@@ -1599,7 +1599,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
15991599
)
16001600

16011601
model_path = os.path.dirname(resolved_archive_file)
1602-
state_dict = load_tp_checkpoint(model_path, cls, config)
1602+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
16031603
model.set_state_dict(state_dict)
16041604

16051605
return model

paddlenlp/experimental/transformers/opt/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
377377
)
378378

379379
model_path = os.path.dirname(resolved_archive_file)
380-
state_dict = load_tp_checkpoint(model_path, cls, config)
380+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
381381
model.set_state_dict(state_dict)
382382
return model
383383

paddlenlp/experimental/transformers/qwen/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
426426
)
427427

428428
model_path = os.path.dirname(resolved_archive_file)
429-
state_dict = load_tp_checkpoint(model_path, cls, config)
429+
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
430430
model.set_state_dict(state_dict)
431431
return model
432432

0 commit comments

Comments
 (0)