Skip to content

Commit

Permalink
change save type (#6044)
Browse files (browse the repository at this point in the history)
  • Loading branch information
lugimzzz authored May 29, 2023
1 parent 2b53d21 commit 1d776a6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions paddlenlp/prompt/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from functools import partial
from typing import Callable, Optional

import numpy as np
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
Expand Down Expand Up @@ -396,7 +397,7 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
]
)
# (num_layers, 2, num_heads, prefixlen, head_dim)
past_key_values = paddle.transpose(past_key_values, perm=[2, 1, 3, 0, 4])
past_key_values = paddle.transpose(past_key_values, perm=[2, 1, 3, 0, 4]).numpy()

if merge_tensor_parallel and self.model.config.tensor_parallel_degree > 1:
trainable_state_dict = self.prefix_encoder.state_dict()
Expand All @@ -421,7 +422,7 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
if is_main_process:
self.prefix_config.save_pretrained(save_directory)
self.prefix_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
paddle.save({"past_key_values": past_key_values}, os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME))
np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_values)

def set_state_dict(self, state_dict):
self.prefix_encoder.set_state_dict(state_dict)
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
PREFIX_CONFIG_NAME = "prefix_config.json"
LORA_WEIGHT_FILE_NAME = "lora_model_state.pdparams"
PREFIX_WEIGHT_FILE_NAME = "prefix_model_state.pdparams"
PAST_KEY_VALUES_FILE_NAME = "past_key_values.pdparams"
PAST_KEY_VALUES_FILE_NAME = "pre_caches.npy"

# for conversion
ENABLE_TORCH_CHECKPOINT = _get_bool_env("ENABLE_TORCH_CHECKPOINT", "true")

0 comments on commit 1d776a6

Please sign in to comment.