fix split_param (#9817)
DesmonDay authored Jan 24, 2025
1 parent 2f85a64 commit 3967f76
Showing 3 changed files with 79 additions and 28 deletions.
32 changes: 17 additions & 15 deletions paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -149,14 +149,6 @@ def _remove_unused_keys(


def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())

if not safe_serialization:
index_filename, index_filename_master_weights = (
PADDLE_OPTIMIZER_INDEX_NAME,
@@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
else:
index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME

with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
index = json.loads(f.read())

ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(
args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())

resolved_archive_file, sharded_metadata = get_optimizer_shard_files(
optimizer_path=resume_from_checkpoint,
index_filename=os.path.join(resume_from_checkpoint, index_filename),
@@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
index = json.loads(f.read())

ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

# update has_master_weights and index_filename_master_weights
# 1. if the master weight exists, only has_master_weights is set True and loaded when needed
# 2. if master weight does not exist, convert model weight to master weight when needed
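The load_local.py change is about ordering: the optimizer index JSON is now parsed before the split-param branch, so the ckpt_quant_stage recorded at save time reaches load_unified_optimizer_split_param instead of only the non-split path. A minimal standalone sketch of that lookup (the helper name is illustrative, not part of the patch; the index filename is whichever constant the serialization mode selects above):

import json
import os

def read_ckpt_quant_stage(resume_from_checkpoint, index_filename):
    # Parse the optimizer index written next to the checkpoint shards.
    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
        index = json.loads(f.read())
    # "O0" (no checkpoint quantization) is the default when the key is absent.
    return index.get("ckpt_quant_stage", "O0")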
71 changes: 59 additions & 12 deletions paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -36,18 +36,25 @@
get_expected_state_dict,
get_optimizer_shard_files,
mapping_optimizer_tp_actions,
update_master_weight_status,
)

__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]


def merge_splited_param(
state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
state_dict,
partial_tensor_list,
param_shape_info,
send_table,
recv_table,
is_master_weights=False,
ckpt_quant_stage="O0",
):
"""Merge the splited param in sharding group."""
global_rank = dist.get_rank()
for key in list(state_dict.keys()):
if state_dict[key].numel().item() == 1: # for example: beta1, beta2
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
continue

static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
)
dist.stream.send(tensor, dst=recv_rank)
state_dict.pop(key)

if ckpt_quant_stage != "O0":
for key in list(state_dict.keys()):
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
static_name = key if is_master_weights else generate_base_static_name(key)[0]
if static_name in partial_tensor_list:
recv_rank = recv_table[static_name]
send_info = send_table[static_name]
if global_rank != recv_rank:
state_dict.pop(key)

return state_dict


def gather_splited_param_for_optimizer(optimizer):
def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
global_rank = dist.get_rank()
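The block added to merge_splited_param covers states the first loop skips: scalar entries such as the beta1/beta2 accumulators (numel == 1) are never sent or merged, but when the checkpoint was saved with quantization (ckpt_quant_stage != "O0") only the receiving rank should keep them. A dependency-free sketch of that filtering rule, with the static-name mapping collapsed for brevity (the real code goes through generate_base_static_name for optimizer states):

def drop_scalar_states_on_non_recv_ranks(state_dict, partial_tensor_list, recv_table, global_rank):
    # Keep scalar accumulators of a partially held parameter only on the rank
    # that received and merged the full tensor; drop them everywhere else.
    for key in list(state_dict.keys()):
        static_name = key  # simplified: assume key is already the static name
        if static_name in partial_tensor_list and global_rank != recv_table[static_name]:
            state_dict.pop(key)
    return state_dict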
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
for key in list(optim_state_dict.keys()):
static_name, _ = generate_base_static_name(key)
if static_name in param_slice_info.keys():
if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2
if int(optim_state_dict[key].numel()) == 1: # for example: beta1, beta2
continue
begin, end = param_slice_info[static_name]
shape, numel, _, _ = param_shape_info[static_name]
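The repeated change from .numel().item() to int(.numel()) keeps the scalar-state check working in more cases: both spellings reduce the element count to a Python int, and int(...) also works if numel() returns a plain integer rather than a 0-D Tensor. A quick illustration (assuming a Paddle install):

import paddle

t = paddle.zeros([3, 4])
print(int(t.numel()))  # 12, whether numel() hands back a Tensor or an int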
@@ -149,13 +167,15 @@
recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor
send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]

merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
merge_splited_param(
optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
)
if master_weights is not None:
merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
return optim_state_dict, master_weights


def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
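For context, the recv/send tables built just above decide how split parameters are reassembled: the first sharding rank that holds a slice becomes the receiver, and every holder is recorded with the flat [begin, end) range it owns. A small sketch of that bookkeeping (input format assumed from the surrounding code):

def build_send_recv_tables(sharding_ranklists):
    # sharding_ranklists: {static_name: [(rank, begin, end), ...]} for each
    # parameter that is split across the sharding group.
    recv_table, send_table = {}, {}
    for key, ranklist in sharding_ranklists.items():
        recv_table[key] = ranklist[0][0]  # which sharding rank receives the merged tensor
        send_table[key] = [(rank, begin, end) for rank, begin, end in ranklist]
    return recv_table, send_table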
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

has_master_weights, index_filename_master_weights = update_master_weight_status(
args, optimizer, has_master_weights, safe_serialization=True
)

if has_master_weights:
returned_optim_state_dict["master_weights"] = {}
resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
if len(resolved_archive_file_mw) > 1:
resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")

def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
def load_resolved_archive_file(
resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
):
returned_state_dict = {}

if model.config.tensor_parallel_degree > 1:
@@ -232,24 +258,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
continue
if model.config.tensor_parallel_degree > 1:
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
state_dict = load_state_dict(
shard_file,
tp_actions,
expected_keys,
device="cpu",
ckpt_quant_stage=ckpt_quant_stage,
)
else:
state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
state_dict = load_state_dict(
shard_file,
None,
expected_keys,
device="cpu",
ckpt_quant_stage=ckpt_quant_stage,
)
returned_state_dict.update(state_dict)
del state_dict
gc.collect()

return returned_state_dict

# get tp params
state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
state_dict_optim = load_resolved_archive_file(
resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
)

# need to split param for different sharding rank, maybe need to deal with oom issue.
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
static_name = struct2static_name_mappings.get(key_name[0], None)

if state_dict_optim[key].numel().item() > 1:
if int(state_dict_optim[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_optim[key] = state_dict_optim[key].reshape([-1])
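Once the shards are read back (now with ckpt_quant_stage forwarded to load_state_dict, so quantized optimizer states are restored at the right stage), each full tensor still has to be cut down to the slice this sharding rank owns. A simplified sketch of that per-rank split, assuming begin/end are flat offsets from param_slice_info and padded_size is the sharding-padded length recorded in param_shape_info:

import paddle

def slice_for_rank(full_tensor, begin, end, padded_size):
    # Flatten, pad up to the sharded length if needed, then keep only [begin, end).
    flat = full_tensor.reshape([-1])
    short = padded_size - int(flat.numel())
    if short > 0:
        flat = paddle.concat([flat, paddle.zeros([short], dtype=flat.dtype)])
    return flat[begin:end]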
@@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings.get(key, None)
if state_dict_master_weight[key].numel().item() > 1:
if int(state_dict_master_weight[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
paddle.framework._current_expected_place(), False
)
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)

# master weight cast (only in remove_master_weight)
if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
)

returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

return returned_optim_state_dict
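The added cast enforces the invariant the comment states for the remove_master_weight case: whatever dtype the recovered value has, the master weight stored back into the optimizer state must be float32. The guard in isolation:

import paddle

def ensure_fp32_master_weight(tensor):
    # Master weights are kept in float32; cast only when the loaded value is not.
    if tensor.dtype != paddle.float32:
        tensor = paddle.cast(tensor, dtype=paddle.float32)
    return tensor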
4 changes: 3 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
return

if is_sharding_split_param_mode(self.args):
optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
optim_state_dict, master_weights = gather_splited_param_for_optimizer(
optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
)
else:
optim_state_dict = nested_copy(optimizer.state_dict())
master_weights = None
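On the save side the change is symmetric: gather_splited_param_for_optimizer now receives the configured ckpt_quant_stage, but only while quantized saving is still permitted; once infohub records quant_reach_limit, the call falls back to "O0". A sketch of that selection, assuming infohub supports dict-style membership tests as the diff implies ("O1" below is just a placeholder stage value):

def select_ckpt_quant_stage(configured_stage, infohub):
    # Disable checkpoint quantization for this save once the quant limit flag is set.
    return configured_stage if "quant_reach_limit" not in infohub else "O0"

print(select_ckpt_quant_stage("O1", {}))                           # "O1"
print(select_ckpt_quant_stage("O1", {"quant_reach_limit": True}))  # "O0"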
