add load save prefix (PaddlePaddle#5797)
lugimzzz authored May 4, 2023
1 parent 93e78c2 commit abb705e
Showing 5 changed files with 159 additions and 45 deletions.
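For context, a minimal save/load sketch using the API added in this commit. The checkpoint paths and base model are illustrative placeholders, and it assumes `PrefixModelForCausalLM` is importable from `paddlenlp.prompt` (as the trainer change below does) and that `AutoModelForCausalLM` is available in this PaddleNLP version:

```python
# Hedged sketch of the new prefix-tuning save/load flow; paths and the base model
# checkpoint are placeholders, not taken from this commit.
from paddlenlp.prompt import PrefixModelForCausalLM
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base_model")  # placeholder checkpoint

# Reload prefix weights previously saved to ./prefix_ckpt
# (prefix_config.json + prefix_model_state.pdparams, per paddlenlp/utils/env.py below).
prefix_model = PrefixModelForCausalLM.from_pretrained(model, "./prefix_ckpt")

# ... fine-tune the prefix encoder ...

# Persist the prefix encoder weights and config again.
prefix_model.save_pretrained("./prefix_ckpt_updated", merge_tensor_parallel=False)
```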
6 changes: 3 additions & 3 deletions paddlenlp/layers/lora.py
@@ -526,7 +526,7 @@ class LoRAConfig:
"help": "Provides fine-grained control over `MergedLoRALinear`. If None, `LoRALinear` is used instead."
},
)
tensor_parallel_degree: int = field(default=1, metadata={"help": "1 for not use tensor parallel"})
tensor_parallel_degree: int = field(default=-1, metadata={"help": "1 for not use tensor parallel"})
dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"})

@property
@@ -643,7 +643,7 @@ def from_pretrained(cls, model, lora_path):
)

# convert parameters to tensor parallel for mp model
if lora_config_tensor_parallel_degree == 1 and model.config.tensor_parallel_degree > 1:
if lora_config_tensor_parallel_degree <= 1 and model.config.tensor_parallel_degree > 1:
lora_state_dict = lora_model._convert_tensor_parallel(lora_state_dict=lora_state_dict)

# set lora state dict
@@ -691,7 +691,7 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
if self.model.config.tensor_parallel_rank != 0:
logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save")
return
self.lora_config.tensor_parallel_degree = 1
self.lora_config.tensor_parallel_degree = -1
return trainable_state_dict

def _convert_tensor_parallel(self, lora_state_dict):
188 changes: 148 additions & 40 deletions paddlenlp/prompt/prefix.py
@@ -14,16 +14,17 @@

import json
import os
import re
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
from typing import Callable, List, Optional
from functools import partial
from typing import Callable, Optional

import paddle
import paddle.nn as nn
from paddle.distributed import fleet

from ..utils.env import PREFIX_CONFIG_NAME
from ..transformers.model_utils import _add_variant
from ..utils.distributed import distributed_gather
from ..utils.env import PREFIX_CONFIG_NAME, PREFIX_WEIGHT_FILE_NAME
from ..utils.log import logger
from .prompt_utils import signature

@@ -35,13 +36,6 @@

@dataclass
class PrefixConfig:
trainable_modules: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to train when applying with Prefix Tuning."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
prefix_dropout: float = field(default=0.0, metadata={"help": "Prefix projection dropout"})
num_prefix_tokens: Optional[int] = field(default=None, metadata={"help": "Number of prefix tokens"})
num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
@@ -53,7 +47,7 @@ class PrefixConfig:
prefix_projection_hidden_size: Optional[int] = field(
default=None, metadata={"help": "The hidden embedding dimension of the transformer model"}
)
tensor_parallel_degree: int = field(default=1, metadata={"help": ("1 for not use tensor parallel")})
tensor_parallel_degree: int = field(default=-1, metadata={"help": ("1 for not use tensor parallel")})
dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"})

@property
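For reference, a `PrefixConfig` might be constructed as below. The field names come from this file's diff and from how the config is used later in `_get_past_key_values`; the values are illustrative, the import path is assumed, and fields hidden in the collapsed part of the hunk may differ:

```python
# Illustrative PrefixConfig for a small model; values are made up.
from paddlenlp.prompt import PrefixConfig  # assumed export path

config = PrefixConfig(
    num_prefix_tokens=64,        # length of the learned prefix
    num_attention_heads=8,
    num_hidden_layers=2,
    hidden_size=512,
    prefix_projection=False,     # True would add the projection MLP
    prefix_dropout=0.0,
    dtype="float32",
)
```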
@@ -145,6 +139,11 @@ def __init__(
self.inference = False
self.postprocess_past_key_value = postprocess_past_key_value
self.pad_attention_mask = pad_attention_mask
if self.prefix_config.tensor_parallel_degree != self.model.config.tensor_parallel_degree:
self.prefix_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
logger.warning(
f"Reset tensor_parallel_degree of prefix_config to {self.model.config.tensor_parallel_degree}."
)

def forward(
self,
@@ -212,15 +211,11 @@ def _prepare_inputs_for_generation(self, *args, **kwargs):
return model_kwargs

def mark_only_prefix_as_trainable(self) -> None:
for name, weight in self.model.state_dict().items():
if self.prefix_config.trainable_modules is not None and any(
re.fullmatch(trainable_module, name) for trainable_module in self.lora_config.trainable_modules
):
weight.stop_gradient = False
else:
weight.stop_gradient = True

for name, weight in self.prefix_encoder.state_dict().items():
# freeze pretrained model
for _, weight in self.model.state_dict().items():
weight.stop_gradient = True
# train prefix encoder only
for _, weight in self.prefix_encoder.state_dict().items():
weight.stop_gradient = False
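The freezing above relies on Paddle's `stop_gradient` flag; a tiny standalone check with a toy layer (not from the diff):

```python
import paddle

layer = paddle.nn.Linear(4, 4)
for _, weight in layer.state_dict().items():
    weight.stop_gradient = True  # frozen: no gradient is computed for this parameter
print(all(w.stop_gradient for w in layer.state_dict().values()))  # True
```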

def _create_prefix_encoder(self):
@@ -281,16 +276,21 @@ def _get_past_key_values(self, batch_size):

# (bs, prefixlen, hidden_dim*layer_num*2/tensor_parallel_degree)
if self.config.tensor_parallel_degree > 1:
split_past_key_values = past_key_values.split(axis=2)
split_past_key_values = past_key_values.split(
num_or_sections=self.config.tensor_parallel_degree, axis=2
)
past_key_values = split_past_key_values[self.model.config.tensor_parallel_rank]
num_attention_heads = self.prefix_config.num_attention_heads // self.config.tensor_parallel_degree
else:
num_attention_heads = self.prefix_config.num_attention_heads

# (bs, prefixlen, layer_num*2, head_num/tensor_parallel_degree, head_dim)
past_key_values = past_key_values.reshape(
[
batch_size,
self.prefix_config.num_prefix_tokens,
self.prefix_config.num_hidden_layers * 2,
self.prefix_config.num_attention_heads // self.config.tensor_parallel_degree,
num_attention_heads,
self.prefix_config.hidden_size // self.prefix_config.num_attention_heads,
]
)
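A standalone illustration of the split-then-reshape above, with toy sizes (bs=2, prefix_len=4, layers=2, heads=8, hidden=64, tensor_parallel_degree=2). Only the shape bookkeeping is mirrored here; the tensor is random and stands in for the prefix encoder output:

```python
import paddle

bs, prefix_len, num_layers, num_heads, hidden, tp, rank = 2, 4, 2, 8, 64, 2, 0

# A flat (bs, prefix_len, layer_num*2*hidden) tensor standing in for the encoder output.
past_key_values = paddle.randn([bs, prefix_len, num_layers * 2 * hidden])

shards = past_key_values.split(num_or_sections=tp, axis=2)  # one shard per mp rank
shard = shards[rank]

# (bs, prefix_len, layer_num*2, head_num/tp, head_dim)
shard = shard.reshape([bs, prefix_len, num_layers * 2, num_heads // tp, hidden // num_heads])
print(shard.shape)  # [2, 4, 4, 4, 8]
```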
@@ -308,51 +308,159 @@ def eval(self):
self.model.eval()
self.prefix_encoder.eval()

def get_model_trainable_state_dict(self):
trainable_state_dict = OrderedDict()
for name, weight in self.model.state_dict().items():
if not weight.stop_gradient:
trainable_state_dict[name] = weight
return trainable_state_dict

def print_trainable_parameters(self) -> None:
freeze_numel = 0
trainable_numel = 0
for name, weight in self.model.state_dict().items():
for _, weight in self.model.state_dict().items():
if weight.stop_gradient:
freeze_numel += weight.numel().item()
else:
trainable_numel += weight.numel().item()
print(name, weight.shape)
for name, weight in self.prefix_encoder.state_dict().items():
for _, weight in self.prefix_encoder.state_dict().items():
if weight.stop_gradient:
freeze_numel += weight.numel().item()
else:
trainable_numel += weight.numel().item()
print(name, weight.shape)
logger.info(
f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters:{trainable_numel:.2e} || Total parameters:{freeze_numel+trainable_numel:.2e}|| Trainable:{trainable_numel / (freeze_numel+trainable_numel):.2%}"
)

@classmethod
def from_pretrained(cls, model, prefix_path):
def from_pretrained(
cls,
model,
prefix_path,
postprocess_past_key_value=None,
pad_attention_mask=None,
):
# TODO(lugimzzz): support load past_key_values in next PR
# init prefix config & prefix model
prefix_config = PrefixConfig.from_pretrained(prefix_path)
prefix_model = cls(model, prefix_config)
# TODO(lugimzzz): support load prefix_encoder parameter and past_key_values
# TODO(lugimzzz): support mp
# preserve the original prefix_config.tensor_parallel_degree in a new variable; the config value is overwritten when the prefix model is initialized
prefix_config_tensor_parallel_degree = prefix_config.tensor_parallel_degree
prefix_model = cls(model, prefix_config, postprocess_past_key_value, pad_attention_mask)

# define prefix weight name
if prefix_config_tensor_parallel_degree > 1:
prefix_weight_name = _add_variant(PREFIX_WEIGHT_FILE_NAME, f"tp{model.config.tensor_parallel_rank:0>2d}")
else:
prefix_weight_name = PREFIX_WEIGHT_FILE_NAME

# load and set prefix weight parameter
prefix_weight_path = os.path.join(prefix_path, prefix_weight_name)
if os.path.exists(prefix_weight_path):
# load prefix weight parameter
prefix_state_dict = paddle.load(prefix_weight_path, return_numpy=True)
logger.info(f"Loading the prefix weights from {prefix_weight_path}")

if (
prefix_config_tensor_parallel_degree > 1
and prefix_config_tensor_parallel_degree != model.config.tensor_parallel_degree
):
raise NotImplementedError(
f"{prefix_config_tensor_parallel_degree} is not equal to {model.config.tensor_parallel_degree}. Please merge prefix weights first."
)

# convert parameters to tensor parallel for mp model
if prefix_config_tensor_parallel_degree <= 1 and model.config.tensor_parallel_degree > 1:
prefix_state_dict = prefix_model._convert_tensor_parallel(prefix_state_dict=prefix_state_dict)

# set prefix state dict
prefix_model.model.set_state_dict(prefix_state_dict)
else:
logger.error(f"prefix weights not found under {prefix_path}, creating prefix weights from scratch")

return prefix_model

def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False):
# TODO(lugimzzz): support load prefix_encoder parameter and past_key_values
# TODO(lugimzzz): support mp
assert not os.path.isfile(
save_directory
), f"Saving directory ({save_directory}) should be a directory, not a file"
os.makedirs(save_directory, exist_ok=True)
prefix_weight_name = PREFIX_WEIGHT_FILE_NAME
if merge_tensor_parallel and self.model.config.tensor_parallel_degree > 1:
trainable_state_dict = self.prefix_encoder.state_dict()
trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict)
else:
trainable_state_dict = self.prefix_encoder.state_dict()
if self.model.config.tensor_parallel_degree > 1:
prefix_weight_name = _add_variant(
PREFIX_WEIGHT_FILE_NAME, f"tp{self.model.config.tensor_parallel_rank:0>2d}"
)
weight_filename = os.path.join(save_directory, prefix_weight_name)
paddle.save(trainable_state_dict, weight_filename)

if self.model.config.tensor_parallel_rank == 0:
self.prefix_config.save_pretrained(save_directory)
self.prefix_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
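Hedged usage note for the tensor-parallel case above: with `merge_tensor_parallel=True` the shards are gathered so a single merged `prefix_model_state.pdparams` is written; otherwise each rank writes its own `tpXX`-suffixed variant file. A sketch (the hybrid-parallel setup and paths are placeholders):

```python
# Assumes prefix_model was built inside an initialized hybrid-parallel (mp) run.
prefix_model.save_pretrained("./prefix_ckpt_tp", merge_tensor_parallel=True)

# Reloading: from_pretrained re-splits merged weights when the target model uses
# tensor parallelism, or loads the per-rank variant files when saved unmerged.
prefix_model = PrefixModelForCausalLM.from_pretrained(model, "./prefix_ckpt_tp")
```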

def _merge_trainable_tensor_parallel(self, trainable_state_dict):
from paddlenlp.transformers.conversion_utils import split_or_merge_func

fn = split_or_merge_func(
is_split=False,
tensor_parallel_degree=self.model.config.tensor_parallel_degree,
tensor_parallel_rank=self.model.config.tensor_parallel_rank,
num_attention_heads=self.model.config.num_attention_heads,
)
if self.prefix_config.prefix_projection:
name_action_mappings = {
"0.weight": partial(fn, is_column=False),
"1.weight": partial(fn, is_column=True),
"1.bias": partial(fn, is_column=True),
"3.weight": partial(fn, is_column=False),
}
else:
name_action_mappings = {
"0.weight": partial(fn, is_column=False),
}
hcg = paddle.distributed.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
is_dst = paddle.distributed.get_rank(mp_group) == 0

for key in trainable_state_dict:
tensor = trainable_state_dict[key]
if key in name_action_mappings:
ret = distributed_gather(tensor, group=mp_group, offload=True)
action = name_action_mappings[key]
tensor = action(ret) if is_dst else None
trainable_state_dict[key] = tensor
else:
trainable_state_dict[key] = tensor.numpy() if is_dst else None

if self.model.config.tensor_parallel_rank != 0:
logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save")
return
self.prefix_config.tensor_parallel_degree = -1
return trainable_state_dict

def _convert_tensor_parallel(self, prefix_state_dict):
from paddlenlp.transformers.conversion_utils import split_or_merge_func

fn = split_or_merge_func(
is_split=True,
tensor_parallel_degree=self.model.config.tensor_parallel_degree,
tensor_parallel_rank=self.model.config.tensor_parallel_rank,
num_attention_heads=self.model.config.num_attention_heads,
)

if self.prefix_config.prefix_projection:
name_action_mappings = {
"0.weight": partial(fn, is_column=False),
"1.weight": partial(fn, is_column=True),
"1.bias": partial(fn, is_column=True),
"3.weight": partial(fn, is_column=False),
}
else:
name_action_mappings = {
"0.weight": partial(fn, is_column=False),
}

for name, action in name_action_mappings.items():
tensor = prefix_state_dict.pop(name)
prefix_state_dict[name] = action(tensor)
return prefix_state_dict
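The split/merge itself is delegated to `split_or_merge_func` from `paddlenlp.transformers.conversion_utils`, whose exact row/column conventions are not shown in this diff. Purely as a conceptual stand-in, splitting a weight across ranks and merging it back might look like:

```python
import numpy as np

weight = np.arange(24, dtype=np.float32).reshape(4, 6)
tp = 2

row_shards = np.split(weight, tp, axis=0)   # "row" style: each rank keeps 2 of 4 rows
col_shards = np.split(weight, tp, axis=1)   # "column" style: each rank keeps 3 of 6 columns

# Concatenating the shards back along the same axis recovers the original weight.
assert np.allclose(np.concatenate(row_shards, axis=0), weight)
assert np.allclose(np.concatenate(col_shards, axis=1), weight)
```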


def bloom_postprocess_past_key_value(past_key_values):
# (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
7 changes: 6 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -46,6 +46,7 @@

from ..data import DataCollator, DataCollatorWithPadding, default_data_collator
from ..layers.lora import LoRAModel
from ..prompt import PrefixModelForCausalLM
from ..transformers.model_utils import PretrainedModel, _add_variant, unwrap_model
from ..transformers.tokenizer_utils import PretrainedTokenizer
from ..utils import device_guard
@@ -1662,7 +1663,11 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_

merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel

if not isinstance(self.model, PretrainedModel) and not isinstance(self.model, LoRAModel):
if (
not isinstance(self.model, PretrainedModel)
and not isinstance(self.model, LoRAModel)
and not isinstance(self.model, PrefixModelForCausalLM)
):
if isinstance(unwrap_model(self.model), PretrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
2 changes: 1 addition & 1 deletion paddlenlp/transformers/configuration_utils.py
@@ -518,7 +518,7 @@ def __init__(self, **kwargs):
self.dtype = kwargs.pop("dtype", paddle.get_default_dtype())

# Parameters for tensor parallel
self.tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
self.tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", -1)
self.tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0)
# If set to True, this option is used with fleet.meta_parallel.ParallelCrossEntropy
# to calculate cross-entropy loss for parallel model.
1 change: 1 addition & 0 deletions paddlenlp/utils/env.py
@@ -73,6 +73,7 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
LORA_CONFIG_NAME = "lora_config.json"
PREFIX_CONFIG_NAME = "prefix_config.json"
LORA_WEIGHT_FILE_NAME = "lora_model_state.pdparams"
PREFIX_WEIGHT_FILE_NAME = "prefix_model_state.pdparams"

# for conversion
ENABLE_TORCH_CHECKPOINT = _get_bool_env("ENABLE_TORCH_CHECKPOINT", "true")
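A quick check of the rank-suffix format that prefix.py combines with `PREFIX_WEIGHT_FILE_NAME` via `_add_variant` (the exact final filename depends on `_add_variant`, which is not shown in this diff):

```python
PREFIX_WEIGHT_FILE_NAME = "prefix_model_state.pdparams"

for rank in range(3):
    print(f"tp{rank:0>2d}")  # tp00, tp01, tp02
```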
