|
52 | 52 | from paddle.utils.download import is_url as is_remote_url |
53 | 53 | from tqdm.auto import tqdm |
54 | 54 |
|
55 | | -from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists |
| 55 | +from paddlenlp.utils.downloader import get_path_from_url_with_filelock |
56 | 56 | from paddlenlp.utils.env import ( |
57 | 57 | CONFIG_NAME, |
58 | 58 | LEGACY_CONFIG_NAME, |
@@ -367,28 +367,7 @@ def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conver |
367 | 367 | support_conversion (bool): whether support converting pytorch weight file to paddle weight file |
368 | 368 | subfolder (str, optional) An optional value corresponding to a folder inside the repo. |
369 | 369 | """ |
370 | | - is_local = os.path.isdir(repo_id) |
371 | | - if not is_local: |
372 | | - if hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder): |
373 | | - file_name = PADDLE_WEIGHTS_NAME |
374 | | - assert ( |
375 | | - support_conversion is False |
376 | | - ), "Please call set convert_from_torch for paddle weights on huggingface hub, eg. Model.from_pretrained(model_name, from_hf_hub=True, convert_from_torch=False)" |
377 | | - elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder): |
378 | | - if not support_conversion: |
379 | | - raise EntryNotFoundError( |
380 | | - f"can not download `{PADDLE_WEIGHTS_NAME} from https://huggingface.co/{repo_id}` " |
381 | | - "and current model doesn't support conversion from pytorch weight file to paddle weight file" |
382 | | - ) |
383 | | - file_name = PYTORCH_WEIGHTS_NAME |
384 | | - else: |
385 | | - raise EntryNotFoundError( |
386 | | - message=f"can not find the paddle/pytorch weight file from: https://huggingface.co/{repo_id}", |
387 | | - response=None, |
388 | | - ) |
389 | | - else: |
390 | | - # for local file, we use support_conversion to select paddle or torch weight. |
391 | | - file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME |
| 370 | + file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME |
392 | 371 |
|
393 | 372 | file_name_list = [SAFE_WEIGHTS_NAME] + [file_name] + [PYTORCH_WEIGHTS_INDEX_NAME] + [SAFE_WEIGHTS_INDEX_NAME] |
394 | 373 | resolved_file = None |
@@ -2156,12 +2135,31 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): |
2156 | 2135 | or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME) |
2157 | 2136 | or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME) |
2158 | 2137 | ): |
2159 | | - # try to get the name-mapping info |
2160 | | - logger.info( |
2161 | | - f"Starting to convert pytorch weight file<{resolved_archive_file}> to " |
2162 | | - f"paddle weight file<{os.path.join(cache_dir, PADDLE_WEIGHTS_NAME)}> ..." |
| 2138 | + converted_paddle_weights = os.path.join( |
| 2139 | + os.path.dirname(resolved_archive_file), PADDLE_WEIGHTS_NAME |
2163 | 2140 | ) |
2164 | | - state_dict = cls.convert(resolved_archive_file, config, cache_dir) |
| 2141 | + if not os.path.exists(converted_paddle_weights): |
| 2142 | + # try to get the name-mapping info |
| 2143 | + logger.info( |
| 2144 | + f"Starting to convert pytorch weight file <{resolved_archive_file}> to " |
| 2145 | + f"paddle weight file <{converted_paddle_weights}> ..." |
| 2146 | + ) |
| 2147 | + state_dict = cls.convert(resolved_archive_file, config, os.path.dirname(resolved_archive_file)) |
| 2148 | + else: |
| 2149 | + # try to load the converted paddle weight file |
| 2150 | + resolved_archive_file = converted_paddle_weights |
| 2151 | + sharded_metadata = None |
| 2152 | + is_sharded = False |
| 2153 | + logger.info( |
| 2154 | + f"Detect the converted Paddle weight file <{converted_paddle_weights}>. We intend to reuse this file." |
| 2155 | + ) |
| 2156 | + if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith( |
| 2157 | + "model_state.pdparams" |
| 2158 | + ): |
| 2159 | + state_dict = cls.convert_tensor_parallel(resolved_archive_file, config) |
| 2160 | + else: |
| 2161 | + state_dict = load_state_dict(resolved_archive_file) |
| 2162 | + logger.info("Loaded weights file from disk, setting weights to model.") |
2165 | 2163 | else: |
2166 | 2164 | raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.") |
2167 | 2165 | else: |
|
0 commit comments