Skip to content

Commit

Permalink
fix optimizer file name
Browse files — browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Sep 12, 2024
1 parent deb1072 commit cc5463c
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,26 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
for key in list(master_weights.keys()):
master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)

no_sync_kname = []
model_state_dict = get_expected_state_dict(model)
for k, v in model_state_dict.items():
if getattr(v, "no_sync", False):
no_sync_kname.append(k)

hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
if self.args.use_expert_parallel:
for k in list(optim_state_dict.keys()):
model_k = k.split("/")[0]
if dp_rank > 0 and model_k not in no_sync_kname:
optim_state_dict.pop(k)
if master_weights is not None:
for k in list(master_weights.keys()):
model_k = k.split("/")[0]
if dp_rank > 0 and model_k not in no_sync_kname:
master_weights.pop(k)

optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)

Expand Down Expand Up @@ -464,7 +484,10 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
key_name = key.split("/")
static_name = struct2static_name_mappings[key_name[0]]
if has_master_weights:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
with device_guard():
Expand Down Expand Up @@ -566,7 +589,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint)
return optim_state_dict

if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
if self.args.data_parallel_rank == 0:
if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
returned_optim_state_dict = self.load_non_merge_optimizer(
model,
optimizer,
Expand Down Expand Up @@ -1942,7 +1965,7 @@ def get_sharded_file_name(args, file_name, is_optimizer=False):
else:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
size = args.world_size if args.use_expert_parallel else dp_group.nranks
size = dp_group.nranks if not args.use_expert_parallel else 1
shard_file = file_name.replace(
".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams"
)
Expand Down Expand Up @@ -2246,7 +2269,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):

if len(tp_actions) > 0:
for x in tp_actions.keys():
logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.")

return state_dict_to_save

Expand Down

0 comments on commit cc5463c

Please sign in to comment.