[Unified Checkpoint] Accelerate loading checkpoint by multi-thread (#9034)

* [Unified Checkpoint] speed up loading checkpoint by multi-thread
Crystal-X-111 authored Oct 24, 2024
1 parent 6211e3d commit 7551730
Showing 1 changed file with 89 additions and 16 deletions.
105 changes: 89 additions & 16 deletions paddlenlp/transformers/model_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

+import concurrent.futures
 import contextlib
 import copy
 import gc
@@ -319,6 +320,65 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
     return last_dtype


+def _split_keys_evenly(keys: list, n: int) -> list:
+    """Split a list into n sublists of approximately equal length.
+
+    Args:
+        keys (list): the list to be split
+        n (int): number of splits
+
+    Returns:
+        result: a list of n sublists
+    """
+    total_len = len(keys)
+    base_size = total_len // n
+    extra = total_len % n
+
+    result = []
+    index = 0
+    for _ in range(n):
+        # The first (total_len % n) groups receive one extra element.
+        part_size = base_size + 1 if extra > 0 else base_size
+        extra -= 1
+        result.append(keys[index : index + part_size])
+        index += part_size
+
+    return result
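
As a quick illustration of the helper's behavior (an editorial sanity check, not part of the commit): the remainder len(keys) % n is spread one element at a time across the leading groups.

    # Illustrative check of _split_keys_evenly; key names are hypothetical.
    keys = [f"layer.{i}.weight" for i in range(7)]
    print(_split_keys_evenly(keys, 3))
    # [['layer.0.weight', 'layer.1.weight', 'layer.2.weight'],
    #  ['layer.3.weight', 'layer.4.weight'],
    #  ['layer.5.weight', 'layer.6.weight']]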


+def _load_part_state_dict(
+    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
+):
+    """Load a subset of the state dict from a checkpoint file.
+
+    Args:
+        keys (list): the keys of the partial state dict to load
+        checkpoint_file (str): the path of the checkpoint file
+        tensor_parallel_split_mapping (dict): mapping from a parameter name to the
+            function that extracts this rank's shard from the safetensors slice
+        fliter_dict_keys (list): if not None, only keys in this list are loaded
+
+    Returns:
+        part_state_dict (dict): the partial state dict
+    """
+    part_state_dict = {}
+    with safe_open(checkpoint_file, framework="np") as f:
+        for key in keys:
+            if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                continue
+            py_safe_slice_ = f.get_slice(key)
+            if key in tensor_parallel_split_mapping:
+                weight = tensor_parallel_split_mapping[key](py_safe_slice_)
+            else:
+                weight = py_safe_slice_[:]
+            if device == "expected":
+                with device_guard():
+                    weight = paddle.Tensor(weight, zero_copy=True)
+                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
+            part_state_dict[key] = weight
+    return part_state_dict


 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
 ):
@@ -343,21 +403,36 @@ def load_state_dict(
         if metadata.get("format", "np") == "pd":
             raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
         if metadata.get("format", "np") == "np":
+            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
             state_dict = {}
-            with safe_open(checkpoint_file, framework="np") as f:
-                for key in f.keys():
-                    if fliter_dict_keys is not None and key not in fliter_dict_keys:
-                        continue
-                    py_safe_slice_ = f.get_slice(key)
-                    if key in tensor_parallel_split_mapping:
-                        weight = tensor_parallel_split_mapping[key](py_safe_slice_)
-                    else:
-                        weight = py_safe_slice_[:]
-                    if device == "expected":
-                        with device_guard():
-                            weight = paddle.Tensor(weight, zero_copy=True)
-                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-                    state_dict[key] = weight
+            if thread_num <= 1:
+                with safe_open(checkpoint_file, framework="np") as f:
+                    state_dict = _load_part_state_dict(
+                        list(f.keys()),
+                        checkpoint_file,
+                        tensor_parallel_split_mapping,
+                        fliter_dict_keys,
+                        device,
+                    )
+            else:
+                # Load state dict in multi-thread to speed up loading
+                with safe_open(checkpoint_file, framework="np") as f:
+                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+                    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                        future_to_key = {
+                            executor.submit(
+                                _load_part_state_dict,
+                                keys,
+                                checkpoint_file,
+                                tensor_parallel_split_mapping,
+                                fliter_dict_keys,
+                                device,
+                            ): keys
+                            for keys in keys_groups
+                        }
+                        for future in concurrent.futures.as_completed(future_to_key):
+                            result = future.result()
+                            state_dict.update(result)

         if device == "cpu":
             for k in list(state_dict.keys()):
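
The thread count comes from the LOAD_STATE_DICT_THREAD_NUM environment variable and defaults to 1, so existing callers keep the single-threaded path unless they opt in. Threads can overlap here even under the GIL because most of the time goes to file reads inside safetensors rather than Python bytecode. A minimal usage sketch; the shard filename is an assumed placeholder:

    # Opt in to 4 loader threads; the variable is read on each load_state_dict call.
    import os

    os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "4"

    from paddlenlp.transformers.model_utils import load_state_dict

    state_dict = load_state_dict("model-00001-of-00008.safetensors")  # hypothetical shard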
@@ -1963,7 +2038,6 @@ def _fuse_or_split_keys(

                 if config.quantization_config.is_weight_quantize():
                     filter_dict_keys = None
-
                 state_dict = load_state_dict(
                     shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys
                 )
@@ -2279,7 +2353,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             else:
                 raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
             # load pt weights early so that we know which dtype to init the model under
-
             if not is_sharded and state_dict is None:
                 # 4. loading non-sharded ckpt from the state dict
                 if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):