[LLM] Support pipeline parallel for llama model. (#5755)
* init for llama pipeline.

Co-authored-by: chenxuyi <[email protected]>
ZHUI and Meiyim authored Apr 28, 2023
1 parent 7586abb commit 93e78c2
Showing 13 changed files with 963 additions and 158 deletions.
31 changes: 31 additions & 0 deletions examples/language_model/llama/README.md
@@ -46,6 +46,37 @@ python -u -m paddle.distributed.fleet.launch \
--warmup_steps 20
```

## Pipeline Parallelism
```shell
python -u -m paddle.distributed.launch \
--gpus "4,5,6,7" finetune_generation.py \
--model_name_or_path facebook/tiny-random-llama \
--do_train \
--do_eval \
--num_train_epochs 1 \
--dataloader_num_workers 1 \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--tensor_parallel_degree 2 \
--pipeline_parallel_degree 2 \
--pipeline_parallel_mirco_batch_size 1 \
--pipeline_parallel_config "disable_p2p_cache_shape" \
--overwrite_output_dir \
--output_dir ./checkpoints/ \
--logging_steps 1 \
--disable_tqdm 1 \
--eval_steps 100 \
--eval_with_do_generation 0 \
--fp16 0 \
--fp16_opt_level O2 \
--recompute \
--learning_rate 3e-5 \
--lr_scheduler_type linear \
--max_grad_norm 1.0 \
--warmup_steps 20
```
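
Here the 4 launched GPUs are split as `tensor_parallel_degree=2` × `pipeline_parallel_degree=2`; the product of the parallel degrees (including data parallelism) must match the number of launched GPUs. As a rough sketch of what these two flags amount to inside Paddle (assuming the Trainer forwards them to fleet's hybrid configuration, which is how hybrid parallelism is normally initialized):

```python
import paddle.distributed.fleet as fleet

# Minimal sketch (not the training script above): map the launch flags onto
# fleet's hybrid parallel configuration. 1 (dp) x 2 (mp) x 2 (pp) = 4 GPUs.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,  # data parallel
    "mp_degree": 2,  # --tensor_parallel_degree
    "pp_degree": 2,  # --pipeline_parallel_degree
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
print(hcg.get_model_parallel_world_size())  # 2
print(hcg.get_pipe_parallel_world_size())   # 2
```

Run it under `python -m paddle.distributed.launch` so every rank joins the hybrid communication group.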

<a name="3"></a>

## Model Prediction
1 change: 1 addition & 0 deletions examples/language_model/llama/data.py
@@ -61,6 +61,7 @@ def convert_example(example, tokenizer, data_args, is_eval=False):
input_seq + output_seq,
return_tensors="pd",
max_length=data_args.src_length + data_args.tgt_length,
padding="max_length" if data_args.always_pad_to_max_length else False,
truncation=True,
)

53 changes: 41 additions & 12 deletions examples/language_model/llama/finetune_generation.py
@@ -18,11 +18,17 @@

import paddle
from data import DataCollatorForSupervisedDataset, convert_example
from modeling_pp import LlamaForCausalLMPipe
from utils import LlamaTrainer, compute_metrics

from paddlenlp.datasets import load_dataset
from paddlenlp.layers import LoRAConfig, LoRAModel
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
PdArgumentParser,
TrainingArguments,
get_last_checkpoint,
set_seed,
)
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
from paddlenlp.utils.log import logger

@@ -53,22 +59,30 @@ class ModelArgument:
model_name_or_path: str = field(
default="facebook/llama-7b", metadata={"help": "Build-in pretrained model name or the path to local model."}
)
label_smoothing: float = field(default=0.1, metadata={"help": "The label smoothing parameter."})
# label_smoothing: float = field(default=0.1, metadata={"help": "The label smoothing parameter."})
lr_decay_ratio: float = field(default=0.1, metadata={"help": "The ratio for learning rate decrease"})
lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
eval_with_do_generation: bool = field(
default=True, metadata={"help": "Evaluate with generation instead of computing the loss."}
)


def main():
parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
data_args.always_pad_to_max_length = False
# data_args.always_pad_to_max_length = training_args.pipeline_parallel_degree > 1

training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
setattr(training_args, "label_smoothing", model_args.label_smoothing)
# setattr(training_args, "label_smoothing", model_args.label_smoothing)
setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio)

paddle.set_device(training_args.device)

set_seed(args=training_args)

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
@@ -98,8 +112,14 @@ def main():
if training_args.bf16:
dtype = "bfloat16"

model_class = AutoModelForCausalLM
if training_args.pipeline_parallel_degree > 1:
if model_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Please set eval_with_do_generation to False in pipeline parallel mode.")
model_class = LlamaForCausalLMPipe

# Load the pretrained language model.
model = AutoModelForCausalLM.from_pretrained(
model = model_class.from_pretrained(
model_args.model_name_or_path,
load_state_as_np=True,
low_cpu_mem_usage=True,
@@ -110,6 +130,7 @@ def main():
use_flash_attention=model_args.use_flash_attention,
use_recompute=training_args.recompute,
)

if model_args.lora:
# TODO: hardcode parameters for now. Change after MergedLoRA is introduced
lora_config = LoRAConfig(
@@ -131,11 +152,17 @@
tokenizer.pad_token = tokenizer.unk_token

# Load the dataset.
train_ds, dev_ds = load_dataset(data_args.task_name, splits=["train_v1", "dev_v1"])
if training_args.do_train or training_args.do_eval:
train_ds, dev_ds = load_dataset(data_args.task_name, splits=["train_v1", "dev_v1"])
trans_func = partial(convert_example, tokenizer=tokenizer, data_args=data_args)

if training_args.do_train:
train_ds = train_ds.map(partial(trans_func))
if training_args.do_eval:
# pipeline_parallel eval is the same as training.
is_eval = model_args.eval_with_do_generation
dev_ds = dev_ds.map(partial(trans_func, is_eval=is_eval))

trans_func = partial(convert_example, tokenizer=tokenizer, data_args=data_args)
train_ds = train_ds.map(partial(trans_func))
dev_ds = dev_ds.map(partial(trans_func, is_eval=True))
collate_fn = DataCollatorForSupervisedDataset(tokenizer)

def compute_metrics_trainer(eval_preds, tokenizer):
@@ -163,11 +190,13 @@ def compute_metrics_trainer(eval_preds, tokenizer):
trainer = LlamaTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=dev_ds,
train_dataset=train_ds if training_args.do_train else None,
eval_dataset=dev_ds if training_args.do_eval else None,
tokenizer=tokenizer,
compute_metrics=compute_metrics_func,
do_generation=True,
compute_metrics=compute_metrics_func
if (model_args.eval_with_do_generation and training_args.do_eval)
else None,
do_generation=model_args.eval_with_do_generation,
data_collator=collate_fn,
)

189 changes: 189 additions & 0 deletions examples/language_model/llama/modeling_pp.py
@@ -0,0 +1,189 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pass
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

from paddlenlp.transformers import PretrainedModel
from paddlenlp.transformers.llama.modeling import (
LlamaConfig,
LlamaDecoderLayer,
LlamaLMHead,
LlamaPretrainedModel,
LlamaPretrainingCriterion,
LlamaRMSNorm,
)


def get_hcg():
return fleet.get_hybrid_communicate_group()


class LlamaEmbedding(nn.Layer):
"""Extends LlamaEmbeddings to forward attention_mask through the pipeline."""

def __init__(self, config):
super(LlamaEmbedding, self).__init__()
if config.tensor_parallel_degree > 1:
self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
)
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

def forward(self, input_ids):
"""_summary_
Args:
input (_type_): _description_
Returns:
_type_: _description_
"""
return self.embed_tokens(input_ids)


class PipelinePretrainedModel(PretrainedModel):
_sequential_layers = []
_pipeline_name_mapping = None

def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)

def add_sequential_layer(self, layer_desc, name_prefix=""):
self._sequential_layers.append({"layer": layer_desc, "name_prefix": name_prefix})

def get_sequential_layers(self):
return [x["layer"] for x in self._sequential_layers]

def get_sequential_name_prefixs(self):
return {str(index): x["name_prefix"] for index, x in enumerate(self._sequential_layers)}

def _set_pipeline_name_mapping(self, mappings=None):
if mappings is not None:
self._pipeline_name_mapping = mappings
else:
mapping = {}
state_dict_keys = list(super().state_dict().keys())
prefixs = self.get_sequential_name_prefixs()
for k in state_dict_keys:
name_splited = k.split(".")
name_splited[0] = prefixs[name_splited[0]]
mapping[".".join(name_splited)] = k
self._pipeline_name_mapping = mapping

return self._pipeline_name_mapping

def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
prefixs = self.get_sequential_name_prefixs()
for k in list(state_dict.keys()):
v = state_dict.pop(k)
name_splited = k.split(".")
name_splited[0] = prefixs[name_splited[0]]
state_dict[".".join(name_splited)] = v

return state_dict

def set_state_dict(self, state_dict, *args, **kwargs):
if self._pipeline_name_mapping is None:
self._set_pipeline_name_mapping()
assert len(self._pipeline_name_mapping) > 0, "The pipeline stage must have parameters!"

for k in list(state_dict.keys()):
v = state_dict.pop(k)
if k not in self._pipeline_name_mapping:
continue
state_dict[self._pipeline_name_mapping[k]] = v

return super().set_state_dict(state_dict, *args, **kwargs)
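
# Illustrative example of the mapping built above (assuming the "llama"/"lm_head"
# prefixes registered in LlamaForCausalLMPipe below; index 0 is the embedding and
# decoder layer i gets index i + 1):
#   "llama.embed_tokens.weight"              <-> "0.embed_tokens.weight"
#   "llama.layers.3.self_attn.q_proj.weight" <-> "4.self_attn.q_proj.weight"
# state_dict() emits the transformer-style names on the left so saved checkpoints
# line up with the non-pipeline LlamaForCausalLM naming; set_state_dict() maps them
# back to PipelineLayer's index-prefixed names.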


class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
"""LlamaForPretraining adapted for pipeline parallelism.
The largest change is flattening the LlamaModel class so we can express it as a
sequence of layers including embedding, transformer layers, and output.
"""

config_class = LlamaConfig

_get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
# NO base_model_prefix !!!!

def __init__(
self,
config,
# num_partitions=1,
# topology=None,
use_recompute=True,
# fused_linear=False,
# fuse_attn_qkv=False,
# scale_qk_by_layer_num=True,
recompute_granularity="full",
virtual_pp_degree=1,
# sequence_parallel=False,
# no_recompute_layers=None,
pp_recompute_interval=1,
# use_flash_attn=False,
# fused_softmax_with_triangular=False,
):
self.config = config

hcg = get_hcg()
tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1)
tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0)

config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank

self.add_sequential_layer(LayerDesc(LlamaEmbedding, config=config), "llama")
for i in range(config.num_hidden_layers):
self.add_sequential_layer(LayerDesc(LlamaDecoderLayer, config=config), f"llama.layers.{i}")

self.add_sequential_layer(LayerDesc(LlamaRMSNorm, config=config), "llama.norm")
self.add_sequential_layer(LayerDesc(LlamaLMHead, config=config), "lm_head")

recompute_interval = 0
if use_recompute and recompute_granularity == "full":
assert pp_recompute_interval <= config.num_hidden_layers // (
virtual_pp_degree * get_hcg().topology().get_dim_size("pipe")
), "pp recompute interval should smaller than num layers of each pp chunk"
recompute_interval = pp_recompute_interval

seg_method = "layer:LlamaDecoderLayer"
if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0:
seg_method = "uniform"

PipelineLayer.__init__(
self,
layers=self.get_sequential_layers(),
loss_fn=LlamaPretrainingCriterion(config),
topology=get_hcg().topology(),
seg_method=seg_method,
recompute_interval=recompute_interval,
recompute_ctx={
"mp_group": get_hcg().get_model_parallel_group(),
"offload": False,
"partition": False,
},
num_virtual_pipeline_stages=virtual_pp_degree,
)
# DON'T init PipelinePretrainedModel
# PipelinePretrainedModel.__init__(self.super(), config=config)