diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py index c2ccbe4d240b..4e37f44b4d19 100644 --- a/paddlenlp/trainer/unified_checkpoint/load_local.py +++ b/paddlenlp/trainer/unified_checkpoint/load_local.py @@ -149,14 +149,6 @@ def _remove_unused_keys( def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - # Special process with split param. - if is_sharding_split_param_mode(args): - returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint) - return returned_optim_state_dict - - # init and get optimizer LR_Scheduler - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - if not safe_serialization: index_filename, index_filename_master_weights = ( PADDLE_OPTIMIZER_INDEX_NAME, @@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin else: index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: + index = json.loads(f.read()) + + ckpt_quant_stage = "O0" + if "ckpt_quant_stage" in index: + ckpt_quant_stage = index["ckpt_quant_stage"] + + # Special process with split param. + if is_sharding_split_param_mode(args): + returned_optim_state_dict = load_unified_optimizer_split_param( + args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage + ) + return returned_optim_state_dict + + # init and get optimizer LR_Scheduler + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( optimizer_path=resume_from_checkpoint, index_filename=os.path.join(resume_from_checkpoint, index_filename), @@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin if len(resolved_archive_file) > 1: resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") - with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: - index = json.loads(f.read()) - - ckpt_quant_stage = "O0" - if "ckpt_quant_stage" in index: - ckpt_quant_stage = index["ckpt_quant_stage"] - # update has_master_weights and index_filename_master_weights # 1. if the master weight exists, only has_master_weights is set True and loaded when needed # 2. if master weight does not exist, convert model weight to master weight when needed diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py index fda80fca0a61..19e88b402bcf 100644 --- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py +++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py @@ -36,18 +36,25 @@ get_expected_state_dict, get_optimizer_shard_files, mapping_optimizer_tp_actions, + update_master_weight_status, ) __all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"] def merge_splited_param( - state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False + state_dict, + partial_tensor_list, + param_shape_info, + send_table, + recv_table, + is_master_weights=False, + ckpt_quant_stage="O0", ): """Merge the splited param in sharding group.""" global_rank = dist.get_rank() for key in list(state_dict.keys()): - if state_dict[key].numel().item() == 1: # for example: beta1, beta2 + if int(state_dict[key].numel()) == 1: # for example: beta1, beta2 continue static_name = key if is_master_weights else generate_base_static_name(key)[0] @@ -89,10 +96,21 @@ def merge_splited_param( ) dist.stream.send(tensor, dst=recv_rank) state_dict.pop(key) + + if ckpt_quant_stage != "O0": + for key in list(state_dict.keys()): + if int(state_dict[key].numel()) == 1: # for example: beta1, beta2 + static_name = key if is_master_weights else generate_base_static_name(key)[0] + if static_name in partial_tensor_list: + recv_rank = recv_table[static_name] + send_info = send_table[static_name] + if global_rank != recv_rank: + state_dict.pop(key) + return state_dict -def gather_splited_param_for_optimizer(optimizer): +def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"): hcg = fleet.get_hybrid_communicate_group() sharding_group = hcg.get_sharding_parallel_group() global_rank = dist.get_rank() @@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer): for key in list(optim_state_dict.keys()): static_name, _ = generate_base_static_name(key) if static_name in param_slice_info.keys(): - if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2 + if int(optim_state_dict[key].numel()) == 1: # for example: beta1, beta2 continue begin, end = param_slice_info[static_name] shape, numel, _, _ = param_shape_info[static_name] @@ -149,13 +167,15 @@ def gather_splited_param_for_optimizer(optimizer): recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist] - merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False) + merge_splited_param( + optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage + ) if master_weights is not None: merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True) return optim_state_dict, master_weights -def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint): +def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"): returned_optim_state_dict = nested_copy(optimizer.state_dict()) index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME @@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check if len(resolved_archive_file) > 1: resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + has_master_weights, index_filename_master_weights = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization=True + ) + if has_master_weights: returned_optim_state_dict["master_weights"] = {} resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( @@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check if len(resolved_archive_file_mw) > 1: resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") - def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + def load_resolved_archive_file( + resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0" + ): returned_state_dict = {} if model.config.tensor_parallel_degree > 1: @@ -232,9 +258,21 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): continue if model.config.tensor_parallel_degree > 1: - state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu") + state_dict = load_state_dict( + shard_file, + tp_actions, + expected_keys, + device="cpu", + ckpt_quant_stage=ckpt_quant_stage, + ) else: - state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu") + state_dict = load_state_dict( + shard_file, + None, + expected_keys, + device="cpu", + ckpt_quant_stage=ckpt_quant_stage, + ) returned_state_dict.update(state_dict) del state_dict gc.collect() @@ -242,14 +280,16 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected return returned_state_dict # get tp params - state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim) + state_dict_optim = load_resolved_archive_file( + resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage + ) # need to split param for different sharding rank, maybe need to deal with oom issue. for key in list(state_dict_optim.keys()): key_name = key.split("/") static_name = struct2static_name_mappings.get(key_name[0], None) - if state_dict_optim[key].numel().item() > 1: + if int(state_dict_optim[key].numel()) > 1: begin, end = param_slice_info[static_name] shape, numel, index, padded_size = param_shape_info[static_name] state_dict_optim[key] = state_dict_optim[key].reshape([-1]) @@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected for key in list(state_dict_master_weight.keys()): static_name = struct2static_name_mappings.get(key, None) - if state_dict_master_weight[key].numel().item() > 1: + if int(state_dict_master_weight[key].numel()) > 1: begin, end = param_slice_info[static_name] shape, numel, index, padded_size = param_shape_info[static_name] state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1]) @@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected paddle.framework._current_expected_place(), False ) returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) + + # master weight cast (only in remove_master_weight) + if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32: + returned_optim_state_dict["master_weights"][static_name] = paddle.cast( + returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32 + ) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) return returned_optim_state_dict diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 0cb38bec94bb..2e453c54e80d 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): return if is_sharding_split_param_mode(self.args): - optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer) + optim_state_dict, master_weights = gather_splited_param_for_optimizer( + optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0" + ) else: optim_state_dict = nested_copy(optimizer.state_dict()) master_weights = None