fix split_param for expert parallel
DesmonDay committed Jan 26, 2025
1 parent 37f3be1 commit 8bed006
Showing 2 changed files with 6 additions and 2 deletions.
@@ -305,7 +305,11 @@ def load_resolved_archive_file(
             )
         )
         if has_master_weights:
-            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                # for moe gate with float32 dtype.
+                key_name = "_".join([static_name, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])

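The hunk above restricts the FP32_MASTER suffix to parameters that actually keep a separate fp32 master copy; parameters that are already float32 (such as MoE gate weights) keep their plain static name, so their optimizer-state keys can still be matched at load time. Below is a minimal sketch of that naming rule; build_optimizer_key, the dtype strings, the example key names, and the FP32_MASTER value are illustrative assumptions, not the actual PaddleNLP API.

# Minimal sketch of the key-naming rule in this commit (hypothetical helper,
# not the actual PaddleNLP implementation).
FP32_MASTER = "fp32_master_0"  # value assumed for illustration

def build_optimizer_key(static_name, optim_suffix, param_dtype, has_master_weights):
    # bf16/fp16 parameters keep a separate fp32 master copy, so their optimizer
    # state is stored under a key carrying the FP32_MASTER tag. Parameters that
    # are already float32 (e.g. MoE gate weights) have no master copy, so the
    # tag must be omitted or the key will not be found when loading.
    if has_master_weights and param_dtype != "float32":
        return "_".join([static_name, FP32_MASTER, optim_suffix])
    return "_".join([static_name, optim_suffix])

# A bf16 expert weight versus a float32 MoE gate weight (illustrative names).
print(build_optimizer_key("linear_0.w_0", "moment1_0", "bfloat16", True))
# linear_0.w_0_fp32_master_0_moment1_0
print(build_optimizer_key("gate_0.w_0", "moment1_0", "float32", True))
# gate_0.w_0_moment1_0
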
paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py (1 addition & 1 deletion)
@@ -637,8 +637,8 @@ def unified_optimizer_into_shards(
     tp_size = tp_group.nranks
     dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

+    no_sync_kname = []
     if args.use_expert_parallel:
-        no_sync_kname = []
         for k, v in state_dict.items():
             if getattr(state_dict[k], "no_sync", False):
                 no_sync_kname.append(k)
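
This second hunk moves the no_sync_kname initialization out of the expert-parallel branch so the list exists on every code path; if it were only created under args.use_expert_parallel, any later reference to it would fail when expert parallelism is disabled. The standalone sketch below illustrates the pattern with hypothetical names (collect_no_sync_keys, _Tensor), not the actual trainer code.

# Standalone sketch of the initialization fix; collect_no_sync_keys and _Tensor
# are hypothetical stand-ins, not the actual unified_optimizer_into_shards code.

class _Tensor:
    def __init__(self, no_sync=False):
        self.no_sync = no_sync

def collect_no_sync_keys(state_dict, use_expert_parallel):
    # Create the list unconditionally so later code can always iterate over it,
    # even when expert parallelism is disabled and it stays empty.
    no_sync_kname = []
    if use_expert_parallel:
        for k, v in state_dict.items():
            if getattr(v, "no_sync", False):
                no_sync_kname.append(k)
    return no_sync_kname

state = {"expert_0.w_0": _Tensor(no_sync=True), "shared_0.w_0": _Tensor()}
print(collect_no_sync_keys(state, use_expert_parallel=True))   # ['expert_0.w_0']
print(collect_no_sync_keys(state, use_expert_parallel=False))  # []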
