[LLM] Unify pipeline model with PretrainModelPipe (#7095)
* unify pipeline model with PretrainModelPipe.
ZHUI authored Sep 22, 2023
1 parent cdc7382 commit 51835e8
Showing 17 changed files with 277 additions and 257 deletions.
4 changes: 2 additions & 2 deletions llm/finetune_generation.py
@@ -109,9 +109,9 @@ def main():
if training_args.pipeline_parallel_degree > 1:
if data_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.")
from llama.modeling_pp import LlamaForCausalLMPipe
from paddlenlp.transformers import AutoModelForCausalLMPipe

model = LlamaForCausalLMPipe.from_pretrained(
model = AutoModelForCausalLMPipe.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=False,
tensor_parallel_degree=training_args.tensor_parallel_degree,
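This is the pattern applied throughout the commit: scripts stop importing a model-specific pipeline class from a local `modeling_pp` module and instead resolve it through the new `AutoModelForCausalLMPipe`. A minimal sketch of the resulting loading path, assuming a PaddleNLP build that includes this commit; the checkpoint name and parallel degrees below are illustrative placeholders, not values from the diff:

```python
# Minimal sketch, assuming a PaddleNLP build that includes this commit.
# The checkpoint name and parallel degrees are illustrative placeholders.
from paddlenlp.transformers import AutoModelForCausalLM, AutoModelForCausalLMPipe

pipeline_parallel_degree = 2  # would normally come from training_args

# Use the unified pipeline class only when pipeline parallelism is enabled;
# otherwise fall back to the ordinary causal-LM class.
model_class = (
    AutoModelForCausalLMPipe if pipeline_parallel_degree > 1 else AutoModelForCausalLM
)

model = model_class.from_pretrained(
    "facebook/llama-7b",  # placeholder checkpoint
    tensor_parallel_output=False,
    tensor_parallel_degree=1,
)
```

In practice this runs under `paddle.distributed.launch`, with the pipeline stages created by the trainer; the snippet only shows how the class is selected and loaded.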
7 changes: 4 additions & 3 deletions llm/gpt-3/README.md
@@ -24,7 +24,7 @@ mv gpt_en_dataset_300m_idx.npz ./data

Note:
1. Training requires the paddle develop build; install the missing wheel packages such as `pip install tool_helpers visualdl==2.5.3`.
2. `use_flash_attention` needs to be enabled on A100 machines, otherwise the loss may be abnormal (it quickly drops to 0.00x, which is abnormally small). A cuda11.8 environment is recommended.
2. `use_flash_attention` needs to be enabled on A100 machines. A cuda11.8 environment is recommended.

Use the script below to continue training on top of gpt2-medium-en.
```shell
@@ -35,7 +35,7 @@ log_dir="log"
rm -rf $log_dir

python -u -m paddle.distributed.launch \
--gpus "0" \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir ${log_dir} \
run_pretrain.py \
--model_type "gpt" \
@@ -49,7 +49,7 @@ python -u -m paddle.distributed.launch \
--per_device_eval_batch_size 1 \
--tensor_parallel_degree 1 \
--pipeline_parallel_degree 1 \
--fuse_attention_qkv 1 \
--fuse_attention_qkv 0 \
--use_flash_attention 0 \
--fp16 \
--fp16_opt_level "O2" \
@@ -62,6 +62,7 @@
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1\
--continue_training \
--dataloader_num_workers 1 \
--sharding "stage2" \
--eval_steps 1000 \
```
8 changes: 6 additions & 2 deletions llm/gpt-3/finetune_generation.py
@@ -18,7 +18,6 @@
from functools import partial

import paddle
from modeling_pp import GPTForCausalLMPipe
from utils import (
DataCollatorForSupervisedDataset,
GPTTrainer,
@@ -34,7 +33,12 @@
get_last_checkpoint,
set_seed,
)
from paddlenlp.transformers import AutoTokenizer, GPTConfig, GPTForCausalLM
from paddlenlp.transformers import (
AutoTokenizer,
GPTConfig,
GPTForCausalLM,
GPTForCausalLMPipe,
)
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
36 changes: 28 additions & 8 deletions llm/gpt-3/run_pretrain.py
@@ -21,7 +21,6 @@
from typing import Optional

import paddle
from modeling_pp import GPTForCausalLMPipe

from paddlenlp.trainer import (
PdArgumentParser,
@@ -36,6 +35,7 @@
CosineAnnealingWithWarmupDecay,
GPTConfig,
GPTForCausalLM,
GPTForCausalLMPipe,
LinearAnnealingWithWarmupDecay,
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
@@ -125,6 +125,16 @@ class ModelArguments:
)
output_attentions: bool = field(default=False, metadata={"help": "Whether output attention weights"})
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
virtual_pp_degree: int = field(
default=1,
metadata={"help": "virtual_pp_degree"},
)
continue_training: bool = field(
default=False,
metadata={
"help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
},
)
fused_linear: bool = field(
default=False,
metadata={"help": "gpt, whether to fuse linear projection"},
@@ -343,7 +353,16 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)

config = config_class.from_pretrained(model_args.model_name_or_path)
# Some techniques extend the RotaryEmbedding context, so don't change max_position_embeddings
if not model_args.continue_training:
config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length)

if not model_args.continue_training:
config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
logger.info(f"Reset vocab size to {config.vocab_size} for better amp performance.")

config.output_attentions = model_args.output_attentions
config.virtual_pp_degree = model_args.virtual_pp_degree
config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length)
config.hidden_dropout_prob = model_args.hidden_dropout_prob
config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
@@ -368,12 +387,14 @@ def main():
if training_args.pipeline_parallel_degree > 1:
model_class = GPTForCausalLMPipe

model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
dtype=dtype,
load_state_as_np=True,
)
if model_args.continue_training:
model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
dtype=dtype,
)
else:
model = model_class._from_config(config, dtype=dtype)

# Create the learning_rate scheduler and optimizer
if training_args.decay_steps is None:
@@ -418,7 +439,6 @@ def main():
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
checkpoint = None

# Training
if training_args.do_train:
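The new `continue_training` flag separates warm-starting from existing PaddleNLP weights from training from scratch, and `virtual_pp_degree` is now threaded through the config. A condensed sketch of that branch, simplified from the diff above; `build_model()` is an illustrative helper, not code from the repository:

```python
# Condensed sketch of the continue_training branch, assuming a PaddleNLP
# build with this commit. build_model() is an illustrative helper, not code
# from the repository.
from paddlenlp.transformers import GPTConfig, GPTForCausalLM, GPTForCausalLMPipe


def build_model(model_args, training_args, dtype="float16"):
    config = GPTConfig.from_pretrained(model_args.model_name_or_path)
    config.virtual_pp_degree = model_args.virtual_pp_degree

    # GPTForCausalLMPipe is now importable from paddlenlp.transformers, so no
    # local modeling_pp module is needed.
    model_class = GPTForCausalLM
    if training_args.pipeline_parallel_degree > 1:
        model_class = GPTForCausalLMPipe

    if model_args.continue_training:
        # Warm-start from existing weights; load_state_as_np is no longer passed.
        return model_class.from_pretrained(
            model_args.model_name_or_path, config=config, dtype=dtype
        )
    # Otherwise train from scratch with randomly initialized weights.
    return model_class._from_config(config, dtype=dtype)
```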
9 changes: 6 additions & 3 deletions llm/llama/benchmark.py
@@ -25,7 +25,6 @@
compute_metrics,
compute_metrics_not_do_generation,
)
from modeling_pp import LlamaForCausalLMPipe

from paddlenlp.data import DataCollatorForSeq2Seq
from paddlenlp.datasets import load_dataset
@@ -37,7 +36,11 @@
get_last_checkpoint,
set_seed,
)
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
from paddlenlp.transformers import (
AutoModelForCausalLM,
AutoModelForCausalLMPipe,
AutoTokenizer,
)
from paddlenlp.utils.log import logger


@@ -207,7 +210,7 @@ def main():
if training_args.pipeline_parallel_degree > 1:
if model_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.")
model_class = LlamaForCausalLMPipe
model_class = AutoModelForCausalLMPipe

# Load the pretrained language model.
model = model_class.from_pretrained(
6 changes: 3 additions & 3 deletions llm/llama/run_pretrain.py
@@ -38,6 +38,7 @@
LinearAnnealingWithWarmupDecay,
LlamaConfig,
LlamaForCausalLM,
LlamaForCausalLMPipe,
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
@@ -50,8 +51,6 @@
),
}

from fused_layers import mock_layers
from modeling_pp import LlamaForCausalLMPipe

from paddlenlp.data.causal_dataset import build_train_valid_test_datasets, print_rank_0

@@ -371,6 +370,8 @@ def main():
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if training_args.enable_linear_fused_grad_add:
from fused_layers import mock_layers

mock_layers()

if model_args.tokenizer_name_or_path is None:
@@ -465,7 +466,6 @@ def main():
model_args.model_name_or_path,
config=config,
dtype=dtype,
load_state_as_np=True,
)
else:
model = model_class._from_config(config, dtype=dtype)
3 changes: 1 addition & 2 deletions llm/llama/tests/test_pipeline_parallel.py
@@ -17,10 +17,9 @@
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from modeling_pp import LlamaForCausalLMPipe
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel

from paddlenlp.transformers import LlamaForCausalLM
from paddlenlp.transformers import LlamaForCausalLM, LlamaForCausalLMPipe


class TestLlama(unittest.TestCase):
3 changes: 1 addition & 2 deletions llm/llama/tests/test_sequence_parallel.py
@@ -17,10 +17,9 @@
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from modeling_pp import LlamaForCausalLMPipe
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel

from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM
from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM, LlamaForCausalLMPipe


class TestLlama(unittest.TestCase):
8 changes: 2 additions & 6 deletions paddlenlp/transformers/__init__.py
@@ -47,9 +47,7 @@
from .bert.configuration import *

# isort: split
from .gpt.modeling import *
from .gpt.tokenizer import *
from .gpt.configuration import *
from .gpt import *
from .roberta.modeling import *
from .roberta.tokenizer import *
from .roberta.configuration import *
@@ -120,9 +118,7 @@
from .funnel.modeling import *
from .funnel.tokenizer import *
from .funnel.configuration import *
from .llama.configuration import *
from .llama.modeling import *
from .llama.tokenizer import *
from .llama import *
from .layoutlm.configuration import *
from .layoutlm.modeling import *
from .layoutlm.tokenizer import *
71 changes: 49 additions & 22 deletions paddlenlp/transformers/auto/modeling.py
@@ -43,6 +43,7 @@
"AutoModelForMultipleChoice",
"AutoModelForMaskedLM",
"AutoModelForCausalLM",
"AutoModelForCausalLMPipe",
"AutoEncoder",
"AutoDecoder",
"AutoGenerator",
@@ -140,6 +141,7 @@
("ForMultipleChoice", "AutoModelForMultipleChoice"),
("ForMaskedLM", "AutoModelForMaskedLM"),
("ForCausalLM", "AutoModelForCausalLM"),
("ForCausalLMPipe", "AutoModelForCausalLMPipe"),
("Encoder", "AutoEncoder"),
("Decoder", "AutoDecoder"),
("Generator", "AutoGenerator"),
@@ -243,17 +245,22 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file
model_class = getattr(import_class, init_class)
return model_class
except AttributeError as err:
logger.error(err)
all_model_classes = import_class.__all__
all_tasks = {get_task_name(m) for m in all_model_classes if get_task_name(m) is not None}
raise AttributeError(
f"module '{import_class.__name__}' only supports the following classes: "
+ ", ".join(m for m in all_model_classes)
+ "\n"
"Hint: you can use interface "
+ " or ".join(task + ".from_pretrained" for task in all_tasks)
+ f" to load '{pretrained_model_name_or_path}'\n"
)
try:
new_import_class = importlib.import_module(f"paddlenlp.transformers.{class_name}")
model_class = getattr(new_import_class, init_class)
return model_class
except AttributeError:
logger.error(err)
all_model_classes = import_class.__all__
all_tasks = {get_task_name(m) for m in all_model_classes if get_task_name(m) is not None}
raise AttributeError(
f"module '{import_class.__name__}' only supports the following classes: "
+ ", ".join(m for m in all_model_classes)
+ "\n"
"Hint: you can use interface "
+ " or ".join(task + ".from_pretrained" for task in all_tasks)
+ f" to load '{pretrained_model_name_or_path}'\n"
)

@classmethod
def _from_pretrained(
@@ -313,17 +320,23 @@ def _from_pretrained(
try:
model_class = getattr(import_class, init_class)
except AttributeError as err:
logger.error(err)
all_model_classes = import_class.__all__
all_tasks = {get_task_name(m) for m in all_model_classes if get_task_name(m) is not None}
raise AttributeError(
f"module '{import_class.__name__}' only supports the following classes: "
+ ", ".join(m for m in all_model_classes)
+ "\n"
"Hint: you can use interface "
+ " or ".join(task + ".from_pretrained" for task in all_tasks)
+ f" to load '{pretrained_model_name_or_path}'\n"
)
try:
import_class2 = importlib.import_module(f"paddlenlp.transformers.{class_name}")
model_class = getattr(import_class2, init_class)
except AttributeError:
logger.error(err)
all_model_classes = import_class.__all__
all_tasks = {
get_task_name(m) for m in all_model_classes if get_task_name(m) is not None
}
raise AttributeError(
f"module '{import_class.__name__}' only supports the following classes: "
+ ", ".join(m for m in all_model_classes)
+ "\n"
"Hint: you can use interface "
+ " or ".join(task + ".from_pretrained" for task in all_tasks)
+ f" to load '{pretrained_model_name_or_path}'\n"
)
logger.info(f"We are using {model_class} to load '{pretrained_model_name_or_path}'.")
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
# From local dir path
@@ -819,6 +832,20 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
return cls._from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


class AutoModelForCausalLMPipe(_BaseAutoModelClass):
"""
Pipeline model for AutoModelForCausalLM.
"""

CONFIGURATION_MODEL_MAPPING = get_init_configurations()
_pretrained_model_dict = CONFIGURATION_MODEL_MAPPING
_name_mapping = get_name_mapping("ForCausalLMPipe")

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
return cls._from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


class AutoEncoder(_BaseAutoModelClass):
"""
AutoEncoder.
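The new `AutoModelForCausalLMPipe` reuses the same configuration-to-class mapping as the other Auto classes, keyed on the `ForCausalLMPipe` suffix, and the loader now retries `paddlenlp.transformers.<model>` (the package, where `modeling_pp` is re-exported) when a class is not found in `<model>.modeling`. A hedged usage sketch; the checkpoint name is a placeholder and the distributed setup normally provided by `paddle.distributed.launch` is omitted:

```python
# Hedged usage sketch, assuming a PaddleNLP build with this commit.
# The checkpoint name is a placeholder; in practice this runs under
# paddle.distributed.launch with pipeline_parallel_degree > 1.
from paddlenlp.transformers import AutoModelForCausalLMPipe

model = AutoModelForCausalLMPipe.from_pretrained(
    "facebook/llama-7b",  # placeholder checkpoint
    tensor_parallel_degree=1,
)

# The architecture recorded in the checkpoint's config resolves to the matching
# *ForCausalLMPipe class, e.g. LlamaForCausalLMPipe for a llama checkpoint.
print(type(model).__name__)
```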
5 changes: 5 additions & 0 deletions paddlenlp/transformers/gpt/__init__.py
@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration import *
from .modeling import *
from .modeling_pp import *
from .tokenizer import *
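With the GPT subpackage now re-exporting its submodules (including `modeling_pp`), the pipeline classes are importable from the package root. An import smoke check, assuming this commit is installed; the llama package is assumed to follow the same layout, as the other files in this commit suggest:

```python
# Import smoke check, assuming a PaddleNLP build with this commit installed.
from paddlenlp.transformers import (
    AutoModelForCausalLMPipe,
    GPTForCausalLMPipe,
    LlamaForCausalLMPipe,
)

# The pipe classes now live under the package root alongside the regular
# model classes, so scripts no longer need a local modeling_pp on sys.path.
for cls in (GPTForCausalLMPipe, LlamaForCausalLMPipe):
    print(cls.__name__, "->", cls.__module__)
```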
4 changes: 3 additions & 1 deletion paddlenlp/transformers/gpt/configuration.py
@@ -325,11 +325,12 @@ def __init__(
output_attentions: bool = False,
ignore_index: int = 0,
use_flash_attention: bool = False,
use_fused_dropout_add: bool = False,
fused_linear: bool = False,
fuse_attention_qkv=False,
enable_fuse_transformer: bool = False,
use_fused_dropout_add: bool = False,
fused_softmax_with_triangular: bool = False,
virtual_pp_degree: int = 1,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -366,3 +367,4 @@ def __init__(
self.enable_fuse_transformer = enable_fuse_transformer
self.use_fused_dropout_add = use_fused_dropout_add
self.fused_softmax_with_triangular = fused_softmax_with_triangular
self.virtual_pp_degree = virtual_pp_degree
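`virtual_pp_degree` is now a first-class `GPTConfig` field with a default of 1, so `run_pretrain.py` can copy the CLI value onto the config before building the model. A small sketch, assuming this commit; every other config argument keeps its default:

```python
# Small sketch of the new config field, assuming a PaddleNLP build with this
# commit; every other GPTConfig argument keeps its default value.
from paddlenlp.transformers import GPTConfig

config = GPTConfig(virtual_pp_degree=2)
assert config.virtual_pp_degree == 2

# run_pretrain.py copies the CLI argument onto the config before the model
# is constructed:
#     config.virtual_pp_degree = model_args.virtual_pp_degree
```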