From 75517304a73e6725cb7bd6815d8c2217ada64d8d Mon Sep 17 00:00:00 2001
From: CrystalX <550820611@qq.com>
Date: Thu, 24 Oct 2024 14:18:16 +0800
Subject: [PATCH] [Unified Checkpoint] Accelerate loading checkpoint by multi-thread (#9034)

* [Unified Checkpoint] speed up loading checkpoint by multi-thread

---
 paddlenlp/transformers/model_utils.py | 105 ++++++++++++++++++++++----
 1 file changed, 89 insertions(+), 16 deletions(-)

diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index 95d67ae788b2..c1af5b4a77bb 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import concurrent.futures
 import contextlib
 import copy
 import gc
@@ -319,6 +320,65 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
     return last_dtype
 
 
+def _split_keys_evenly(keys: list, n: int) -> list:
+    """Split a list into n sublists whose lengths differ by at most one element.
+
+    Args:
+        keys (list): the list to be split
+        n (int): number of splits
+
+    Returns:
+        result (list): list of n sublists
+    """
+
+    total_len = len(keys)
+    base_size = total_len // n
+    extra = total_len % n
+
+    result = []
+    index = 0
+    for _ in range(n):
+        part_size = base_size + 1 if extra > 0 else base_size
+        extra -= 1
+        result.append(keys[index : index + part_size])
+        index += part_size
+
+    return result
+
+
+def _load_part_state_dict(
+    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
+):
+    """Load a part of the state dict from a checkpoint file.
+
+    Args:
+        keys (list): the keys of the partial state dict to load
+        checkpoint_file (str): the path of the checkpoint file
+        tensor_parallel_split_mapping (dict): mapping from key to its tensor-parallel split action
+        fliter_dict_keys (list): if not None, only keys in this list are loaded
+        device (str): "cpu" or "expected"; "expected" moves weights to the current expected place
+
+    Returns:
+        part_state_dict (dict): the partial state dict
+    """
+    part_state_dict = {}
+    with safe_open(checkpoint_file, framework="np") as f:
+        for key in keys:
+            if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                continue
+            py_safe_slice_ = f.get_slice(key)
+            if key in tensor_parallel_split_mapping:
+                weight = tensor_parallel_split_mapping[key](py_safe_slice_)
+            else:
+                weight = py_safe_slice_[:]
+            if device == "expected":
+                with device_guard():
+                    weight = paddle.Tensor(weight, zero_copy=True)
+                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
+            part_state_dict[key] = weight
+    return part_state_dict
+
+
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
 ):
@@ -343,21 +403,36 @@ def load_state_dict(
         if metadata.get("format", "np") == "pd":
             raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
         if metadata.get("format", "np") == "np":
+            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
             state_dict = {}
-            with safe_open(checkpoint_file, framework="np") as f:
-                for key in f.keys():
-                    if fliter_dict_keys is not None and key not in fliter_dict_keys:
-                        continue
-                    py_safe_slice_ = f.get_slice(key)
-                    if key in tensor_parallel_split_mapping:
-                        weight = tensor_parallel_split_mapping[key](py_safe_slice_)
-                    else:
-                        weight = py_safe_slice_[:]
-                    if device == "expected":
-                        with device_guard():
-                            weight = paddle.Tensor(weight, zero_copy=True)
-                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-                    state_dict[key] = weight
+            if thread_num <= 1:
+                with safe_open(checkpoint_file, framework="np") as f:
+                    state_dict = _load_part_state_dict(
+                        list(f.keys()),
+                        checkpoint_file,
+                        tensor_parallel_split_mapping,
+                        fliter_dict_keys,
+                        device,
+                    )
+            else:
+                # Load the state dict with multiple threads to speed up loading
+                with safe_open(checkpoint_file, framework="np") as f:
+                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+
+                with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                    future_to_key = {
+                        executor.submit(
+                            _load_part_state_dict,
+                            keys,
+                            checkpoint_file,
+                            tensor_parallel_split_mapping,
+                            fliter_dict_keys,
+                            device,
+                        ): keys
+                        for keys in keys_groups
+                    }
+                    for future in concurrent.futures.as_completed(future_to_key):
+                        result = future.result()
+                        state_dict.update(result)
 
         if device == "cpu":
             for k in list(state_dict.keys()):
@@ -1963,7 +2038,6 @@ def _fuse_or_split_keys(
                 if config.quantization_config.is_weight_quantize():
                     filter_dict_keys = None
 
-
                 state_dict = load_state_dict(
                     shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys
                 )
@@ -2279,7 +2353,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             else:
                 raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
         # load pt weights early so that we know which dtype to init the model under
-
         if not is_sharded and state_dict is None:
             # 4. loading non-sharded ckpt from the state dict
             if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
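
Usage sketch (not part of the patch): the multi-threaded path is opt-in via the
LOAD_STATE_DICT_THREAD_NUM environment variable, which load_state_dict reads at call
time; the checkpoint path below is a hypothetical placeholder.

    import os

    # Must be set before load_state_dict runs; values <= 1 keep the
    # original single-threaded path.
    os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "8"

    from paddlenlp.transformers.model_utils import load_state_dict

    # "model.safetensors" stands in for a numpy-format safetensors checkpoint.
    state_dict = load_state_dict("model.safetensors", device="cpu")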
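Design note: each worker calls safe_open on the checkpoint file inside
_load_part_state_dict instead of sharing the handle opened in load_state_dict, so no
file object crosses thread boundaries; since the per-key reads are I/O-bound, a
ThreadPoolExecutor suffices. The split helper assigns the remainder to the earliest
sublists, for example (illustrative only, the helper is private):

    # _split_keys_evenly(["a", "b", "c", "d", "e"], 3)
    # -> [["a", "b"], ["c", "d"], ["e"]]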