Skip to content

Commit d94606d

Browse files
committed
add model weight converter command
1 parent dbbd3ea commit d94606d

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

paddlenlp/commands/cli.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
# limitations under the License.
1414

1515
import os
16+
import json
1617
from typing import Type, List, Tuple
1718
from typer import Typer
1819
import shutil
1920
import importlib, inspect
20-
from dataclasses import dataclass
2121
from paddlenlp.transformers import AutoModel, AutoTokenizer, PretrainedModel, PretrainedTokenizer
2222
from paddlenlp.utils.log import logger
2323
from paddlenlp.utils.env import MODEL_HOME
@@ -102,6 +102,54 @@ def search(query: str):
102102
tabulate(tables, headers=['model type', 'model name'], tablefmt="grid"))
103103

104104

105+
@app.command()
def convert(model_type: str,
            config_or_model_name: str,
            pytorch_checkpoint_path: str = 'pytorch',
            dump_output: str = "model_state.pdparams"):
    """Convert a pytorch weight file to a paddle weight file.

    Args:
        model_type (str): target paddle model type, e.g. ``bert`` or ``albert``.
        config_or_model_name (str): a pretrained model name registered in the
            target model class (e.g. ``bert-base-uncased``), or the path of a
            local json configuration file.
        pytorch_checkpoint_path (str, optional): path of the pytorch weight
            file, or of a directory containing ``pytorch_model.bin``.
            Defaults to 'pytorch'.
        dump_output (str, optional): path of the converted paddle weight file.
            Defaults to "model_state.pdparams".

    Raises:
        FileNotFoundError: if the pytorch checkpoint file or the configuration
            file can not be found.
    """
    # 1. resolve pytorch weight file path: a directory is interpreted as
    # containing the conventional `pytorch_model.bin` file
    if os.path.isdir(pytorch_checkpoint_path):
        pytorch_checkpoint_path = os.path.join(pytorch_checkpoint_path,
                                               "pytorch_model.bin")
    if not os.path.isfile(pytorch_checkpoint_path):
        # raise (not assert): validation must survive `python -O`
        raise FileNotFoundError(
            "pytorch checkpoint file {} not found".format(
                pytorch_checkpoint_path))

    def resolve_configuration(model_class: Type[PretrainedModel]) -> dict:
        # prefer a registered pretrained configuration; otherwise treat the
        # argument as the path of a local json configuration file
        if config_or_model_name in model_class.pretrained_init_configuration:
            return model_class.pretrained_init_configuration[
                config_or_model_name]
        if not os.path.isfile(config_or_model_name):
            raise FileNotFoundError(
                "can't find the configuration file by <{}>".format(
                    config_or_model_name))
        with open(config_or_model_name, 'r', encoding='utf-8') as f:
            return json.load(f)

    # 2. dispatch to the model-specific weight converter
    if model_type == 'bert':
        from paddlenlp.transformers.bert.modeling import convert_pytorch_weights, BertModel
        config = resolve_configuration(BertModel)
        model = BertModel(**config)
        convert_pytorch_weights(model,
                                pytorch_checkpoint_path=pytorch_checkpoint_path)
    elif model_type == 'albert':
        from paddlenlp.transformers.albert.modeling import AlbertModel
        config = resolve_configuration(AlbertModel)
        model = AlbertModel(**config)
        # TODO(wj-Mcat): albert weight conversion is not implemented yet —
        # call the albert `convert` method here once it exists
152+
105153
def main():
    """Entry point of the PaddleNLP command-line interface."""
    app()

paddlenlp/transformers/bert/modeling.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415
import warnings
16+
import os
1517

1618
import paddle
1719
import paddle.nn as nn
@@ -23,6 +25,7 @@
2325
FusedTransformerEncoderLayer = None
2426
from dataclasses import dataclass
2527
from typing import List, Optional, Tuple, Union
28+
from ...utils.log import logger
2629
from .. import PretrainedModel, register_base_model
2730
from ..model_outputs import (
2831
BaseModelOutputWithPastAndCrossAttentions,
@@ -1667,3 +1670,14 @@ def forward(self,
16671670
hidden_states=outputs.hidden_states,
16681671
attentions=outputs.attentions,
16691672
)
1673+
1674+
1675+
def convert_pytorch_weights(model: BertPretrainedModel,
                            pytorch_checkpoint_path: str):
    """Convert a pytorch checkpoint into paddle weights for *model*.

    NOTE: this is a work-in-progress stub — it currently only loads the
    pytorch state dict; the actual name-mapping/conversion is still TODO.

    Args:
        model (BertPretrainedModel): the paddle model instance that will
            receive the converted weights.
        pytorch_checkpoint_path (str): path of the pytorch weight file
            (a ``torch.save``-serialized state dict).
    """
    # 1. load the pytorch model weight file
    # (fixed: previously loaded from an undefined name `torch_file`)
    import torch
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints
    torch_weight: dict = torch.load(pytorch_checkpoint_path)
    paddle_weight = {}

    # 2. load mapping configuration
    # TODO(wj-Mcat): from existing codebase

0 commit comments

Comments
 (0)