
Commit f5b02c3

fix optimizer file name

DesmonDay committed Sep 12, 2024
1 parent: deb1072
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -1942,7 +1942,7 @@ def get_sharded_file_name(args, file_name, is_optimizer=False):
     else:
         hcg = fleet.get_hybrid_communicate_group()
         dp_group = hcg.get_data_parallel_group()
-        size = args.world_size if args.use_expert_parallel else dp_group.nranks
+        size = dp_group.nranks if not args.use_expert_parallel else 1
         shard_file = file_name.replace(
             ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams"
         )

Codecov / codecov/patch warning on line 1945 (paddlenlp/trainer/plugins/unified_checkpoint.py#L1945): added line was not covered by tests.
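For context, here is a minimal standalone sketch of what the new expression produces. The concrete values below (world_size, data-parallel group size, logical_process_index, the optimizer file name) are hypothetical and only illustrate the string formatting; they are not taken from the commit.

# Minimal sketch of the file-name suffix after this change; hypothetical values,
# not the repository's full get_sharded_file_name implementation.
world_size = 8              # hypothetical total number of ranks (args.world_size)
dp_nranks = 2               # hypothetical data-parallel group size (dp_group.nranks)
logical_process_index = 3   # hypothetical index of this process
use_expert_parallel = True

# New logic: with expert parallelism on, the divisor is 1, so the suffix
# counts world_size shards instead of world_size // dp_nranks.
size = dp_nranks if not use_expert_parallel else 1

file_name = "optimizer.pdparams"  # hypothetical name ending in ".pdparams"
shard_file = file_name.replace(
    ".pdparams",
    f"-{logical_process_index + 1:05d}-of-{world_size // size:05d}.pdparams",
)
print(shard_file)  # optimizer-00004-of-00008.pdparams

Under the removed expression, size equaled args.world_size whenever use_expert_parallel was set, so the suffix collapsed to "-...-of-00001" regardless of how many optimizer shards actually exist; that mislabeling appears to be what the commit message "fix optimizer file name" refers to.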
@@ -2246,7 +2246,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
 
     if len(tp_actions) > 0:
         for x in tp_actions.keys():
-            logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
+            logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
 
     return state_dict_to_save

Codecov / codecov/patch warning on line 2249 (paddlenlp/trainer/plugins/unified_checkpoint.py#L2249): added line was not covered by tests.
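The second hunk only lowers the log level for unmatched tensor-parallel keys. As a rough illustration using Python's standard logging module (a stand-in here; PaddleNLP ships its own logger wrapper), a debug call is suppressed at the default WARNING level, so this message disappears from ordinary runs:

import logging

logging.basicConfig()  # root handler at the default WARNING level
logger = logging.getLogger("unified_checkpoint_demo")

missing_key = "linear_0.w_0"  # hypothetical key with no match in the model state
logger.warning(f"key <{missing_key}> need to merge tensor parallel but we can't find in model state.")  # emitted
logger.debug(f"key <{missing_key}> need to merge tensor parallel but we can't find in model state.")    # silent unless the level is lowered to DEBUG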

