
Commit

Update checkpoint_converter.py
xingmingyyj authored Jul 31, 2024
1 parent 3bdedab commit 5fc66c2
Showing 1 changed file with 4 additions and 3 deletions: paddlenlp/trainer/checkpoint_converter.py
@@ -206,9 +206,10 @@ def gen_matadata_for_optimizer(self):
         for k, v in state_dict.items():
             for prame_name, opt_name in structure_name_mapping.items():
                 if opt_name in k:
-                    new_key = k.replace(opt_name, prame_name) + "_tp" + "{:02d}".format(distributed_rank[0])
-                else:
-                    new_key = k.replace(opt_name, prame_name)
+                    if v.shape[0] > 1:
+                        new_key = k.replace(opt_name, prame_name) + "_tp" + "{:02d}".format(distributed_rank[0])
+                    else:
+                        new_key = k.replace(opt_name, prame_name)
             renamed_state_dict[new_key] = v
             # Calculate the local_shape.
             cur_rank_sharded_tensor_infos[(new_key, file)] = [v.shape, str(v.dtype)]
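The commit gates the `_tp` suffix on the tensor's leading dimension, so scalar optimizer accumulators (shape `[1]`) no longer receive a tensor-parallel suffix during key renaming. A minimal standalone sketch of that logic follows; the function name `rename_optimizer_keys`, the lightweight `Tensor` tuple, and the sample key names are illustrative assumptions, not the actual PaddleNLP API:

```python
from collections import namedtuple

# Lightweight stand-in for a real tensor: only shape and dtype metadata matter here.
Tensor = namedtuple("Tensor", ["shape", "dtype"])

def rename_optimizer_keys(state_dict, structure_name_mapping, tp_rank):
    """Rename optimizer-state keys back to parameter names, appending a
    tensor-parallel suffix only to sharded (non-scalar) tensors."""
    renamed_state_dict = {}
    for k, v in state_dict.items():
        new_key = k
        for param_name, opt_name in structure_name_mapping.items():
            if opt_name in k:
                if v.shape[0] > 1:
                    # Sharded tensor: tag with the tensor-parallel rank.
                    new_key = k.replace(opt_name, param_name) + "_tp" + "{:02d}".format(tp_rank)
                else:
                    # Scalar accumulator (e.g. beta1_pow): no suffix needed.
                    new_key = k.replace(opt_name, param_name)
        renamed_state_dict[new_key] = v
    return renamed_state_dict

# Hypothetical example: one sharded moment tensor, one scalar accumulator.
state = {
    "opt_param_0.moment1_0": Tensor(shape=(4096, 1024), dtype="float32"),
    "opt_param_0.beta1_pow_acc_0": Tensor(shape=(1,), dtype="float32"),
}
mapping = {"linear_0.w_0": "opt_param_0"}
print(rename_optimizer_keys(state, mapping, tp_rank=3))
```

With these inputs the sharded moment key becomes `linear_0.w_0.moment1_0_tp03`, while the scalar accumulator maps to `linear_0.w_0.beta1_pow_acc_0` with no suffix, which is exactly the behavior the added `if v.shape[0] > 1` branch introduces.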
