|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
| 16 | +import json |
16 | 17 | from typing import Type, List, Tuple |
17 | 18 | from typer import Typer |
18 | 19 | import shutil |
19 | 20 | import importlib, inspect |
20 | | -from dataclasses import dataclass |
21 | 21 | from paddlenlp.transformers import AutoModel, AutoTokenizer, PretrainedModel, PretrainedTokenizer |
22 | 22 | from paddlenlp.utils.log import logger |
23 | 23 | from paddlenlp.utils.env import MODEL_HOME |
@@ -102,6 +102,54 @@ def search(query: str): |
102 | 102 | tabulate(tables, headers=['model type', 'model name'], tablefmt="grid")) |
103 | 103 |
|
104 | 104 |
|
| 105 | +@app.command() |
| 106 | +def convert(model_type: str, |
| 107 | + config_or_model_name: str, |
| 108 | + pytorch_checkpoint_path: str = 'pytorch', |
| 109 | + dump_output: str = "model_state.pdparams"): |
| 110 | + # convert pytorch weight file to paddle weight file |
| 111 | + |
| 112 | + # Args: |
| 113 | + # model_type (str): the name of target paddle model name, which can be: bert, bert-base-uncased |
| 114 | + # torch_checkpoint_path (str, optional): the path of target pytorch weight file . Defaults to 'pytorch'. |
| 115 | + |
| 116 | + # 1. resolve pytorch weight file path |
| 117 | + if os.path.isdir(pytorch_checkpoint_path): |
| 118 | + pytorch_checkpoint_path = os.path.join(pytorch_checkpoint_path, |
| 119 | + "pytorch_model.bin") |
| 120 | + if not os.path.isfile(pytorch_checkpoint_path): |
| 121 | + raise FileNotFoundError( |
| 122 | + "pytorch checkpoint file {} not found".format( |
| 123 | + pytorch_checkpoint_path)) |
| 124 | + elif not os.path.exists(pytorch_checkpoint_path): |
| 125 | + raise FileNotFoundError("pytorch checkpoint file {} not found".format( |
| 126 | + pytorch_checkpoint_path)) |
| 127 | + |
| 128 | + def resolve_configuration(model_class: Type[PretrainedModel]) -> dict: |
| 129 | + if config_or_model_name in model_class.pretrained_init_configuration: |
| 130 | + return model_class.pretrained_init_configuration[ |
| 131 | + config_or_model_name] |
| 132 | + assert os.path.isfile( |
| 133 | + config_or_model_name |
| 134 | + ), f'can"t not find the configuration file by <{config_or_model_name}>' |
| 135 | + with open(config_or_model_name, 'r', encoding='utf-8') as f: |
| 136 | + config = json.load(f) |
| 137 | + return config |
| 138 | + |
| 139 | + # 2. convert different model weight file with |
| 140 | + if model_type == 'bert': |
| 141 | + from paddlenlp.transformers.bert.modeling import convert_pytorch_weights, BertModel |
| 142 | + config = resolve_configuration(BertModel) |
| 143 | + model = BertModel(**config) |
| 144 | + convert_pytorch_weights(model, |
| 145 | + pytorch_checkpoint_path=pytorch_checkpoint_path) |
| 146 | + elif model_type == 'albert': |
| 147 | + from paddlenlp.transformers.albert.modeling import AlbertModel |
| 148 | + config = resolve_configuration(AlbertModel) |
| 149 | + model = AlbertModel(**config) |
| 150 | + # call `convert` method |
| 151 | + |
| 152 | + |
105 | 153 | def main(): |
106 | 154 | """the PaddleNLPCLI entry""" |
107 | 155 | app() |
0 commit comments