add filter_sync_parameters
DesmonDay committed Feb 11, 2025
1 parent 26e51d7 commit c0042fb
Showing 3 changed files with 35 additions and 41 deletions.
@@ -308,7 +308,7 @@ def load_resolved_archive_file(
             if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
-                # for moe gate with float32 dtype.
+                # for parameters with float32 dtype, no need to have fp32 master weights.
                 key_name = "_".join([static_name, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
46 changes: 6 additions & 40 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -67,6 +67,7 @@
     FP32_MASTER,
     UnifiedCheckpointOption,
     filter_params,
+    filter_sync_parameters,
     gather_sharded_object,
     generate_base_static_name,
     get_expected_state_dict,
@@ -218,25 +219,9 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)

-        no_sync_kname = []
-        model_state_dict = get_expected_state_dict(model)
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)
-
-        hcg = fleet.get_hybrid_communicate_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-        if self.args.use_expert_parallel:
-            for k in list(optim_state_dict.keys()):
-                model_k = k.split("/")[0]
-                if dp_rank > 0 and model_k not in no_sync_kname:
-                    optim_state_dict.pop(k)
-            if master_weights is not None:
-                for k in list(master_weights.keys()):
-                    model_k = k.split("/")[0]
-                    if dp_rank > 0 and model_k not in no_sync_kname:
-                        master_weights.pop(k)
+        model_state_dict = get_expected_state_dict(model)
+        filter_sync_parameters(model_state_dict, optim_state_dict, master_weights, is_model_weight=False)

         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
@@ -518,12 +503,7 @@ def unified_checkpoint_into_shards(

     if args.use_expert_parallel:
         # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
-        hcg = fleet.get_hybrid_communicate_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-        for key in list(state_dict.keys()):
-            if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                state_dict.pop(key)
+        filter_sync_parameters(state_dict, is_model_weight=True)

     if config_to_save.tensor_parallel_degree > 1:
         if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
@@ -631,25 +611,11 @@ def unified_optimizer_into_shards(
     filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True)
     filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True)

-    hcg = fleet.get_hybrid_communicate_group()
-    tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
+    tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
     tp_size = tp_group.nranks
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

     if args.use_expert_parallel:
-        no_sync_kname = []
-        for k, v in state_dict.items():
-            if getattr(state_dict[k], "no_sync", False):
-                no_sync_kname.append(k)
-        for key in list(optim_state_dict.keys()):
-            model_key = key.split("/")[0]
-            if dp_rank > 0 and model_key not in no_sync_kname:
-                optim_state_dict.pop(key)
-        if master_weights is not None:
-            for key in list(master_weights.keys()):
-                if dp_rank > 0 and key not in no_sync_kname:
-                    master_weights.pop(key)
+        filter_sync_parameters(state_dict, optim_state_dict, master_weights, is_model_weight=False)

     if tp_size > 1:
         # get tp_actions
28 changes: 28 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/utils.py
@@ -758,3 +758,31 @@ def save_config(model_to_save):
     # save generation config
     if model_to_save.can_generate():
         model_to_save.generation_config.save_pretrained(save_directory)
+
+
+def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
+    """Filter sync parameters under expert parallel mode."""
+
+    hcg = fleet.get_hybrid_communicate_group()
+    dp_group = hcg.get_data_parallel_group()
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+
+    if is_model_weight:
+        for key in list(model_state_dict.keys()):
+            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                model_state_dict.pop(key)
+    else:
+        no_sync_kname = []
+        for k, v in model_state_dict.items():
+            if getattr(v, "no_sync", False):
+                no_sync_kname.append(k)
+
+        for key in list(optim_state_dict.keys()):
+            model_key = key.split("/")[0]
+            if dp_rank > 0 and model_key not in no_sync_kname:
+                optim_state_dict.pop(key)
+
+        if master_weights is not None:
+            for key in list(master_weights.keys()):
+                if dp_rank > 0 and key not in no_sync_kname:
+                    master_weights.pop(key)
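
A minimal standalone sketch (not part of this commit) of the keep/drop rule that the new filter_sync_parameters helper applies. The keep_on_dp_rank function, the fake parameter names, and the SimpleNamespace stand-ins for Paddle tensors are assumptions for illustration only; in PaddleNLP the state dicts come from get_expected_state_dict(model) and the optimizer, and dp_rank comes from fleet.get_hybrid_communicate_group().

from types import SimpleNamespace

def keep_on_dp_rank(key, tensor, dp_rank, is_model_weight, no_sync_names=None):
    # Mirrors the rule above: dp_rank 0 keeps everything; higher data-parallel
    # ranks keep only entries whose parameter is marked no_sync=True.
    if dp_rank == 0:
        return True
    if is_model_weight:
        return getattr(tensor, "no_sync", False)
    # Optimizer keys look like "<param_name>/<state_name>".
    return key.split("/")[0] in (no_sync_names or set())

# Fake model state: one expert-local weight (no_sync=True), one shared weight.
model_state = {
    "moe.expert_0.weight": SimpleNamespace(no_sync=True),
    "embedding.weight": SimpleNamespace(no_sync=False),
}
no_sync_names = {k for k, v in model_state.items() if v.no_sync}

for rank in (0, 1):
    kept = [k for k, v in model_state.items() if keep_on_dp_rank(k, v, rank, True)]
    print(f"dp_rank={rank} saves model weights: {kept}")

optim_state = {"moe.expert_0.weight/moment1_0": 1, "embedding.weight/moment1_0": 2}
kept = [k for k in optim_state if keep_on_dp_rank(k, None, 1, False, no_sync_names)]
print("dp_rank=1 saves optimizer entries:", kept)

Run locally, dp_rank 0 keeps every entry while dp_rank 1 keeps only the no_sync (expert-local) ones, matching the comment in unified_checkpoint_into_shards above.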
