Set DistModel state_dict keys to structure_names (#60478)
* exclude xpu

* check structure name mapping

* test pp

* polish

* support dynamic save static load

* support dygraph save static load

* polish

* polish

* use structured_name as key in DistModel state_dict

* polish

* polish

* fix checkpoint path conflict

* test get_rank_to_files

* static save dynamic load test
pangengzheng authored Jan 5, 2024
1 parent 116c892 commit 57feb0a
Showing 11 changed files with 414 additions and 41 deletions.
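Taken together with the commit messages above ("support dynamic save static load", "support dygraph save static load"), the intent is that a checkpoint written from a dygraph model can be restored into a DistModel running under auto-parallel static mode, because both sides now key their state by structured names. A rough end-to-end sketch, assuming the paddle.distributed.save_state_dict / load_state_dict checkpoint APIs touched in this diff and a hypothetical checkpoint directory:

    import paddle
    import paddle.distributed as dist

    layer = paddle.nn.Linear(8, 8)
    ckpt_dir = "./ckpt_demo"  # hypothetical path

    # Dygraph save: keys are structured names such as "weight" / "bias".
    dist.save_state_dict(layer.state_dict(), ckpt_dir)

    # Later, in auto-parallel static mode, DistModel.state_dict() also uses
    # structured names, so the same checkpoint can be matched key by key:
    # dist_model = dist.to_static(layer, dist_loader, loss_fn, optimizer)
    # state = dist_model.state_dict()          # structured-name keys
    # dist.load_state_dict(state, ckpt_dir)    # fills the tensors in place
    # dist_model.set_state_dict(state)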
22 changes: 21 additions & 1 deletion python/paddle/distributed/auto_parallel/api.py
@@ -1015,6 +1015,12 @@ def __init__(
):
self._feed_name_list = []
self._inner_strategy = self.__convert_strategy(strategy)
self._structured_to_parameter_name = {
k: v.name for k, v in layer.state_dict().items()
}
self._parameter_to_structured_name = {
v: k for k, v in self._structured_to_parameter_name.items()
}
self._engine = Engine(
layer, loss, optimizer, metrics, strategy=self._inner_strategy
)
@@ -1257,6 +1263,15 @@ def state_dict(self, mode="all"):
mode=self._engine._mode
).state_dict(mode)
dist_state_dict = self._build_distributed_state_dict(local_state_dict)
mapping_names = [
self._parameter_to_structured_name[k]
if k in self._parameter_to_structured_name
else k
for k in dist_state_dict.keys()
]
dist_state_dict = dict(
zip(mapping_names, list(dist_state_dict.values()))
)
return dist_state_dict

def _build_distributed_state_dict(self, local_state_dict):
@@ -1331,7 +1346,12 @@ def set_state_dict(self, state_dict):
].process_mesh or check_placements_equal(
v.placements, cur_v.placements
), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
local_state_dict[k] = v._local_value()
param_name = (
self._structured_to_parameter_name[k]
if k in self._structured_to_parameter_name
else k
)
local_state_dict[param_name] = v._local_value()
dist_main_program.set_state_dict(local_state_dict)


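For reference, the structured-name to parameter-name translation used in the DistModel changes above can be reproduced on any paddle.nn.Layer; this is only an illustrative sketch of the two dict comprehensions, not the DistModel code itself:

    import paddle

    layer = paddle.nn.Linear(4, 4)

    # Structured name (state_dict key) -> internal parameter name,
    # e.g. "weight" -> "linear_0.w_0" (the exact suffix depends on creation order).
    structured_to_parameter = {k: v.name for k, v in layer.state_dict().items()}
    parameter_to_structured = {v: k for k, v in structured_to_parameter.items()}

    # state_dict() above rewrites its keys through parameter_to_structured,
    # while set_state_dict() maps them back before feeding the static program.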
62 changes: 39 additions & 23 deletions python/paddle/distributed/checkpoint/load_state_dict.py
@@ -15,7 +15,7 @@
import copy
import os
from dataclasses import dataclass
from typing import Tuple
from typing import Dict, Tuple

import paddle
from paddle.distributed.communication.group import is_initialized
@@ -37,14 +37,36 @@ class ReadItem:
lengths: Tuple[int]


def get_rank_to_files(path, state_dict, process_group, use_dist):
PATH_TO_CHECKPOINT_FILES: Dict[str, Tuple[list, list]] = {}


def get_checkpoint_files(path, use_cache=True):
global PATH_TO_CHECKPOINT_FILES
if use_cache and path in PATH_TO_CHECKPOINT_FILES:
return PATH_TO_CHECKPOINT_FILES[path]
accessible_files = os.listdir(path)
metadata_files = [
file for file in accessible_files if file.endswith(".metadata")
]
assert (
len(metadata_files) > 0
), f"No metadata file found in the checkpoint directory:{path}."
local_data_files = [
file for file in accessible_files if file.endswith(".distcp")
]
assert (
len(local_data_files) > 0
), f"No data file found in the checkpoint directory:{path}."
if use_cache:
PATH_TO_CHECKPOINT_FILES[path] = (metadata_files, local_data_files)
return (metadata_files, local_data_files)


def get_rank_to_files(path, state_dict, process_group, use_dist):
"""
Get the mapping of rank to its accessible files.
"""
metadata_files, local_data_files = get_checkpoint_files(path)
# The necessary files to be read
tensor_key_list = []
necessary_files = []
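The new get_checkpoint_files helper centralizes the directory scan that several functions below previously repeated, and memoizes the result per path in the module-level PATH_TO_CHECKPOINT_FILES dict. A minimal standalone sketch of the same memoization pattern, with hypothetical names but the same file suffixes:

    import os
    from typing import Dict, Tuple

    _CHECKPOINT_FILE_CACHE: Dict[str, Tuple[list, list]] = {}

    def list_checkpoint_files(path: str, use_cache: bool = True) -> Tuple[list, list]:
        """Return (metadata_files, data_files), scanning each path only once."""
        if use_cache and path in _CHECKPOINT_FILE_CACHE:
            return _CHECKPOINT_FILE_CACHE[path]
        files = os.listdir(path)
        metadata_files = [f for f in files if f.endswith(".metadata")]
        data_files = [f for f in files if f.endswith(".distcp")]
        if use_cache:
            _CHECKPOINT_FILE_CACHE[path] = (metadata_files, data_files)
        return metadata_files, data_files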
@@ -62,12 +84,10 @@ def get_rank_to_files(path, state_dict, process_group, use_dist):
logger.warning(
f"No necessary data files found in the checkpoint directory:{path}. Please check the metadata_files:{metadata_files}"
)
return {}
missing_keys = set(state_dict.keys())
return {}, missing_keys

# allgather all accessible files
local_data_files = [
file for file in accessible_files if file.endswith(".distcp")
]
global_data_files = []
if use_dist:
paddle.distributed.all_gather_object(
@@ -101,12 +121,16 @@
]
rank_to_files[rank] = local_files
logger.debug(f"mapping rank_to_files:{rank_to_files}")
return rank_to_files
return rank_to_files, missing_keys


def get_local_load_files(rank_to_files):
"""
Load files in a load-balanced manner.
Args:
rank_to_files (dict): mapping from rank to files.
Example:
Case1: all ranks access the same data files
rank_to_files = {rank0:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp], rank1:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp]}
@@ -196,13 +220,7 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file)

def get_load_infos(path, local_load_files, process_group, use_dist):
load_info = {}
accessible_files = os.listdir(path)
metadata_files = [
file for file in accessible_files if file.endswith(".metadata")
]
assert (
len(metadata_files) > 0
), "No metadata file found in the checkpoint directory:{path}."
metadata_files, _ = get_checkpoint_files(path)
for metadata_file in metadata_files:
metadata = paddle.load(os.path.join(path, metadata_file))
for local_tensor_index, file_name in metadata.storage_metadata.items():
@@ -277,14 +295,8 @@ def not_overlap(


def get_read_items(path, state_dict, process_group, use_dist):
accessible_files = os.listdir(path)
metadata_files = [
file for file in accessible_files if file.endswith(".metadata")
]
assert (
len(metadata_files) > 0
), "No metadata file found in the checkpoint directory:{path}."
storage_state_dict_metadata = {}
metadata_files, _ = get_checkpoint_files(path)
for metadata_file in metadata_files:
metadata = paddle.load(os.path.join(path, metadata_file))
for (
@@ -410,7 +422,7 @@ def load_state_dict(
for val in flat_state_dict.values():
assert isinstance(
val, paddle.Tensor
), f"Only support dygraph Tensor now, but is {val}"
), f"The value of state_dict should be a paddle.Tensor, but got: {val}."

use_dist = True if paddle.distributed.get_world_size() > 1 else False

@@ -422,9 +434,13 @@
# sync to avoid some ranks not write path yet
paddle.distributed.barrier(process_group)

rank_to_files = get_rank_to_files(
rank_to_files, missing_keys = get_rank_to_files(
path, flat_state_dict, process_group, use_dist
)
if len(missing_keys) > 0:
logger.warning(
f"The following keys:{missing_keys} are not found in checkpoint path: {path}."
)
if len(rank_to_files) <= 0:
return
local_load_files = get_local_load_files(rank_to_files)
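Note that get_rank_to_files now returns a (rank_to_files, missing_keys) pair instead of a bare dict, and load_state_dict merely warns when some requested keys are absent from the checkpoint. A hypothetical caller-side sketch of the new contract (the helper name _plan_local_reads is made up; only get_rank_to_files comes from the module above):

    from paddle.distributed.checkpoint.load_state_dict import get_rank_to_files

    def _plan_local_reads(path, flat_state_dict, process_group, use_dist):
        rank_to_files, missing_keys = get_rank_to_files(
            path, flat_state_dict, process_group, use_dist
        )
        if missing_keys:
            # Keys requested but present in no .metadata file; they are
            # skipped with a warning instead of aborting the whole load.
            print(f"warning: keys not found in checkpoint: {sorted(missing_keys)}")
        if not rank_to_files:
            return None  # nothing for this rank to read
        return rank_to_files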
5 changes: 4 additions & 1 deletion python/paddle/distributed/checkpoint/save_state_dict.py
@@ -139,7 +139,7 @@ def save_state_dict(
for val in flat_state_dict.values():
assert isinstance(
val, paddle.Tensor
), "Only support dygraph Tensor now, support static DistributedTensor later"
), f"The value of state_dict should be a paddle.Tensor, but got: {val}."

if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
@@ -188,6 +188,8 @@ def save_state_dict(
if local_shape is None or global_offset is None:
continue
local_tensor = val._local_value()
# Note: The local_tensor must keep the same name as the original tensor. Otherwise, the StructuredToParameterName@@ mapping will be wrong.
local_tensor.name = val.name
else:
local_shape = tuple(val.shape)
global_offset = (
@@ -203,6 +205,7 @@
local_storage_metadata[
LocalTensorIndex(key, tuple(global_offset))
] = file_name

global_state_dict_metadata = []
global_storage_metadata = []
global_flatten_mapping = []
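The local_tensor.name reassignment above exists because _local_value() returns a tensor with a freshly generated name, while name-keyed metadata saved next to the weights (for example the StructuredToParameterName@@ mapping) still refers to the original parameter name, so the name is restored to keep the two consistent. A minimal sketch of that save-side pattern, assuming val is one value of the flattened state dict as in the loop above:

    # Sketch of handling one entry before it is written to a .distcp file.
    if val.is_dist():                      # distributed tensor: save only the local shard
        local_tensor = val._local_value()
        # _local_value() produces a new name; restore the original so that
        # name-keyed metadata (e.g. StructuredToParameterName@@) still matches.
        local_tensor.name = val.name
    else:
        local_tensor = val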
@@ -15,7 +15,7 @@
import os

import numpy as np
from auto_parallel.hybrid_strategy.save_state_dict import (
from auto_parallel.hybrid_strategy.semi_auto_save_state_dict import (
get_global_state_dict,
)
