From 797efa690eaa0545e33ce2a6a11a6087979b2b71 Mon Sep 17 00:00:00 2001 From: yujun <50394665+JunnYu@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:26:46 +0800 Subject: [PATCH] [Fix Download] update converted logic & fix hf hub download subfolder bug (#7911) * update converted logic & fix hf hub download subfolder bug --- paddlenlp/transformers/model_utils.py | 54 +++++++++++++-------------- paddlenlp/transformers/utils.py | 2 +- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index 74df2c24bfc9..ac8450d4fdcc 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -52,7 +52,7 @@ from paddle.utils.download import is_url as is_remote_url from tqdm.auto import tqdm -from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists +from paddlenlp.utils.downloader import get_path_from_url_with_filelock from paddlenlp.utils.env import ( CONFIG_NAME, LEGACY_CONFIG_NAME, @@ -367,28 +367,7 @@ def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conver support_conversion (bool): whether support converting pytorch weight file to paddle weight file subfolder (str, optional) An optional value corresponding to a folder inside the repo. """ - is_local = os.path.isdir(repo_id) - if not is_local: - if hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder): - file_name = PADDLE_WEIGHTS_NAME - assert ( - support_conversion is False - ), "Please call set convert_from_torch for paddle weights on huggingface hub, eg. Model.from_pretrained(model_name, from_hf_hub=True, convert_from_torch=False)" - elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder): - if not support_conversion: - raise EntryNotFoundError( - f"can not download `{PADDLE_WEIGHTS_NAME} from https://huggingface.co/{repo_id}` " - "and current model doesn't support conversion from pytorch weight file to paddle weight file" - ) - file_name = PYTORCH_WEIGHTS_NAME - else: - raise EntryNotFoundError( - message=f"can not find the paddle/pytorch weight file from: https://huggingface.co/{repo_id}", - response=None, - ) - else: - # for local file, we use support_conversion to select paddle or torch weight. - file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME + file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME file_name_list = [SAFE_WEIGHTS_NAME] + [file_name] + [PYTORCH_WEIGHTS_INDEX_NAME] + [SAFE_WEIGHTS_INDEX_NAME] resolved_file = None @@ -2156,12 +2135,31 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME) or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME) ): - # try to get the name-mapping info - logger.info( - f"Starting to convert pytorch weight file<{resolved_archive_file}> to " - f"paddle weight file<{os.path.join(cache_dir, PADDLE_WEIGHTS_NAME)}> ..." + converted_paddle_weights = os.path.join( + os.path.dirname(resolved_archive_file), PADDLE_WEIGHTS_NAME ) - state_dict = cls.convert(resolved_archive_file, config, cache_dir) + if not os.path.exists(converted_paddle_weights): + # try to get the name-mapping info + logger.info( + f"Starting to convert pytorch weight file <{resolved_archive_file}> to " + f"paddle weight file <{converted_paddle_weights}> ..." + ) + state_dict = cls.convert(resolved_archive_file, config, os.path.dirname(resolved_archive_file)) + else: + # try to load the converted paddle weight file + resolved_archive_file = converted_paddle_weights + sharded_metadata = None + is_sharded = False + logger.info( + f"Detect the converted Paddle weight file <{converted_paddle_weights}>. We intend to reuse this file." + ) + if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith( + "model_state.pdparams" + ): + state_dict = cls.convert_tensor_parallel(resolved_archive_file, config) + else: + state_dict = load_state_dict(resolved_archive_file) + logger.info("Loaded weights file from disk, setting weights to model.") else: raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.") else: diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index ecd6f77b790f..869173e55950 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -587,7 +587,7 @@ def cached_file_for_hf_hub( download_check(path_or_repo_id, full_filename, addition="from_hf_hub") resolved_file = hf_hub_download( repo_id=path_or_repo_id, - filename=full_filename, + filename=filename, cache_dir=cache_dir, subfolder=subfolder, library_name="PaddleNLP",