diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py
index 23f085e18f44..198cd8df35c9 100644
--- a/paddlenlp/trainer/utils/ckpt_converter.py
+++ b/paddlenlp/trainer/utils/ckpt_converter.py
@@ -19,17 +19,17 @@ from typing import List, Union
 
 import paddle
-from paddle.distributed.checkpoint.load_state_dict import (
+from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
     _load_state_dict,
     get_rank_to_read_files,
 )
-from paddle.distributed.checkpoint.metadata import (
+from paddle.distributed.flex_checkpoint.dcp.metadata import (
     LocalTensorIndex,
     LocalTensorMetadata,
     Metadata,
 )
-from paddle.distributed.checkpoint.utils import flatten_state_dict
-from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
 
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
@@ -206,7 +206,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 global_offset = [0] * self.tp_degree
                 for item in shard_info:
                     tp_rank = item[0]["tp_rank"]
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
                     local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
                     global_offset[tp_rank] += item[1][0]
@@ -225,7 +225,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 renamed_state_dict = {}
                 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
                 for state_name, state_value in state_dict.items():
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     renamed_state_dict[state_name_with_tp_rank] = state_value
 
                 source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
@@ -235,7 +235,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             sharding_metas_keys = []
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
+                    sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
             for key in sharding_metas_keys:
                 param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                 for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +253,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             all_param_meta = {}
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    key = "tp{:02d}_pp{:02d}".format(i, j)
+                    key = f"tp{i:02d}_pp{j:02d}"
                     param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                     for param_name, param_shape_and_dtype in param_meta.items():
                         all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
                 for key in cur_rank_need_load_model_state_keys:
                     for tp_rank in range(self.tp_degree):
-                        tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+                        tp_rank_suffix = f"_tp{tp_rank:02d}"
                         optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                             (param_flattened_shapes[key],), "float32"
                         )
@@ -353,7 +353,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 else:
                     concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in concat_optimier_state_dict.items():
@@ -472,7 +472,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                     reshaped_v = v.reshape(shape)
                     target_state_dict[k] = reshaped_v
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in target_state_dict.items():
@@ -911,7 +911,7 @@ def rename_using_model_meta(self, file_name):
             self.model_meta = json.load(file)
 
         (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
-        dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
+        dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
         # Map model weight names to their corresponding names of master_weights in the optimizer state.
         if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
            structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]