[Unified Checkpoint] Accelerate loading checkpoint by multi-thread #9034

Merged 11 commits on Oct 24, 2024
105 changes: 89 additions & 16 deletions paddlenlp/transformers/model_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import concurrent.futures
import contextlib
import copy
import gc
@@ -319,6 +320,65 @@
    return last_dtype


def _split_keys_evenly(keys: list, n: int) -> list:
    """Split a list into n sublists whose lengths differ by at most one.

    Args:
        keys (list): the list to be split
        n (int): number of splits

    Returns:
        result (list): list of n sublists of keys
    """

    total_len = len(keys)
    base_size = total_len // n
    extra = total_len % n

    result = []
    index = 0
    for _ in range(n):
        part_size = base_size + 1 if extra > 0 else base_size
        extra -= 1
        result.append(keys[index : index + part_size])
        index += part_size

    return result
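For reference, a minimal sketch of the resulting split (standalone illustration, assuming _split_keys_evenly is in scope; the key names are hypothetical). A remainder of len(keys) % n is spread one element at a time over the first splits:

keys = [f"layer_{i}.weight" for i in range(5)]  # hypothetical key names
assert _split_keys_evenly(keys, 2) == [keys[:3], keys[3:]]
assert _split_keys_evenly(keys, 3) == [keys[:2], keys[2:4], keys[4:]]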


def _load_part_state_dict(
    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
):
    """Load a partial state dict from a checkpoint file.

    Args:
        keys (list): the keys of the partial state dict
        checkpoint_file (str): the path of the checkpoint file
        tensor_parallel_split_mapping (dict): mapping from key to a tensor-parallel split function
        fliter_dict_keys (list): if not None, only keys in this list are loaded
        device (str): "cpu" keeps weights on the CPU; "expected" copies them to the current expected place

    Returns:
        part_state_dict (dict): the partial state dict

    """
    part_state_dict = {}
    with safe_open(checkpoint_file, framework="np") as f:
        for key in keys:
            if fliter_dict_keys is not None and key not in fliter_dict_keys:
                continue
            py_safe_slice_ = f.get_slice(key)
            if key in tensor_parallel_split_mapping:
                weight = tensor_parallel_split_mapping[key](py_safe_slice_)
            else:
                weight = py_safe_slice_[:]
            if device == "expected":
                with device_guard():
                    weight = paddle.Tensor(weight, zero_copy=True)
                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
            part_state_dict[key] = weight
    return part_state_dict
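Each call opens its own safe_open handle on the shard, so the worker threads spawned below never share file-handle state. A minimal sketch of invoking the helper directly (the shard path and key names are hypothetical; no tensor-parallel splitting, no key filter):

part = _load_part_state_dict(
    ["embeddings.weight", "lm_head.weight"],  # hypothetical keys to load
    "model-00001-of-00002.safetensors",       # hypothetical shard path
    {},                                       # no tensor-parallel split functions
    None,                                     # fliter_dict_keys=None keeps every requested key
    "cpu",                                    # leave weights on the CPU
)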


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
):
@@ -343,21 +403,36 @@
if metadata.get("format", "np") == "pd":
raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
if metadata.get("format", "np") == "np":
        thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
        state_dict = {}
        if thread_num <= 1:
            with safe_open(checkpoint_file, framework="np") as f:
                state_dict = _load_part_state_dict(
                    list(f.keys()),
                    checkpoint_file,
                    tensor_parallel_split_mapping,
                    fliter_dict_keys,
                    device,
                )
        else:
            # Load the state dict in multiple threads to speed up loading
            with safe_open(checkpoint_file, framework="np") as f:
                keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
            with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
                future_to_key = {
                    executor.submit(
                        _load_part_state_dict,
                        keys,
                        checkpoint_file,
                        tensor_parallel_split_mapping,
                        fliter_dict_keys,
                        device,
                    ): keys
                    for keys in keys_groups
                }
                for future in concurrent.futures.as_completed(future_to_key):
                    result = future.result()
                    state_dict.update(result)

if device == "cpu":
for k in list(state_dict.keys()):
@@ -1963,7 +2038,6 @@

if config.quantization_config.is_weight_quantize():
    filter_dict_keys = None
state_dict = load_state_dict(
    shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys
)
@@ -2279,7 +2353,6 @@
else:
    raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
# load pt weights early so that we know which dtype to init the model under
if not is_sharded and state_dict is None:
    # 4. loading non-sharded ckpt from the state dict
    if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
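As the diff shows, the thread count is read from the LOAD_STATE_DICT_THREAD_NUM environment variable and defaults to 1, which preserves the original single-threaded path. A minimal sketch of opting in (the shard filename is hypothetical, and 8 threads is an arbitrary choice):

import os

os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "8"  # must be set before load_state_dict runs

from paddlenlp.transformers.model_utils import load_state_dict

state_dict = load_state_dict("model-00001-of-00008.safetensors")  # hypothetical shard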