diff --git a/.copyright.hook b/.copyright.hook
index 6a57d7649d64..a472e581d3f3 100644
--- a/.copyright.hook
+++ b/.copyright.hook
@@ -71,7 +71,7 @@ RE_SHEBANG = re.compile(r"^[ \t\v]*#[ \t]?\!")
 def _check_copyright(path):
     head=[]
     try:
-        with open(path) as f:
+        with open(path, encoding="utf-8") as f:
             head = [next(f) for x in range(4)]
     except StopIteration:
         pass
diff --git a/llm/argument.py b/llm/argument.py
index 64e736873ca2..f0736902519a 100644
--- a/llm/argument.py
+++ b/llm/argument.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass, field
+from typing import List, Optional
 
 from paddlenlp.trainer import TrainingArguments
 from paddlenlp.trainer.trainer_utils import IntervalStrategy
@@ -48,6 +49,9 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    pad_to_multiple_of: int = field(
+        default=None, metadata={"help": "If set, pad the sequence length to a multiple of the provided value."}
+    )
     src_length: int = field(default=1024, metadata={"help": "The maximum length of source(context) tokens."})
     max_length: int = field(
         default=2048,
@@ -102,6 +106,64 @@ class ModelArgument:
         default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
     )
     use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
+    tokenizer_name_or_path: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    use_fused_rms_norm: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the fused rms_norm kernel (llama and other supported models)."},
+    )
+    fuse_attention_qkv: bool = field(
+        default=False,
+        metadata={"help": "Whether to fuse the attention qkv projections."},
+    )
+    fuse_attention_ffn: bool = field(
+        default=False,
+        metadata={"help": "Whether to fuse the up and gate projections in the mlp block."},
+    )
+    recompute_granularity: str = field(
+        default="full",
+        metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"},
+    )
+    virtual_pp_degree: int = field(
+        default=1,
+        metadata={"help": "The virtual pipeline parallel degree."},
+    )
+    hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
+    attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})
+
+    continue_training: bool = field(
+        default=False,
+        metadata={
+            "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
+        },
+    )
+    sequence_parallel: bool = field(
+        default=False,
+        metadata={"help": "Whether to use sequence parallel."},
+    )
+    fuse_sequence_parallel_allreduce: bool = field(
+        default=False,
+        metadata={"help": "Whether to fuse the sequence parallel allreduce."},
+    )
+    use_fused_rope: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable rope fusion or not."},
+    )
+    no_recompute_layers: Optional[List[int]] = field(
+        default=None,
+        metadata={"help": "Specify the full transformer layers that should not be recomputed."},
+    )
+    pp_recompute_interval: int = field(
+        default=1,
+        metadata={
+            "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Defaults to 1."
+        },
+    )
+    recompute_use_reentrant: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the reentrant implementation of recompute."},
+    )
     weight_quantize_algo: str = field(
         default=None,
         metadata={
diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index c3396230c0a0..a3288038f5f4 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -154,6 +154,30 @@ def main():
     if hasattr(model_config, "use_flash_attention"):
         model_config.use_flash_attention = model_args.use_flash_attention
 
+    model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
+    model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
+    model_config.recompute_granularity = model_args.recompute_granularity
+    model_config.virtual_pp_degree = model_args.virtual_pp_degree
+    model_config.sequence_parallel = model_args.sequence_parallel
+    model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    model_config.use_fused_rope = model_args.use_fused_rope
+
+    model_config.no_recompute_layers = model_args.no_recompute_layers
+    model_config.pp_recompute_interval = model_args.pp_recompute_interval
+    model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
+    model_config.use_recompute = training_args.recompute
+
+    model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
+    model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
+
+    # Config for models that use dropout, such as GPT.
+    model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
+    model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+
+    model_config.sep_parallel_degree = training_args.sep_parallel_degree
+    model_config.tensor_parallel_output = True
+    model_config.seq_length = data_args.max_length
     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
@@ -494,6 +518,7 @@ def compute_metrics_do_generation(eval_preds):
             padding=padding,
             max_label_length=max_length,
             return_tensors="np",
+            pad_to_multiple_of=data_args.pad_to_multiple_of,
         ),
         do_generation=data_args.eval_with_do_generation,
         callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 77ea2929a31d..b98bc7f1ac83 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import math
+import os
 import warnings
 from functools import partial
 from typing import Optional, Tuple
@@ -75,8 +76,6 @@ def swiglu(x, y=None):
 
 try:
     if get_env_device() == "npu":
-        import os
-
         from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -94,6 +93,13 @@ def swiglu(x, y=None):
 ]
 
 
+def is_mc2_valid():
+    current_device = get_env_device()
+    if current_device == "npu":
+        return True
+    return False
+
+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -565,8 +571,17 @@ def __init__(self, config):
         self.fuse_attention_ffn = config.fuse_attention_ffn
 
         if config.sequence_parallel:
-            ColumnParallelLinear = ColumnSequenceParallelLinear
-            RowParallelLinear = RowSequenceParallelLinear
+            if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
+                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
+                    MC2ColumnSeqParallelLinear,
+                    MC2RowSeqParallelLinear,
+                )
+
+                ColumnParallelLinear = MC2ColumnSeqParallelLinear
+                RowParallelLinear = MC2RowSeqParallelLinear
+            else:
+                ColumnParallelLinear = ColumnSequenceParallelLinear
+                RowParallelLinear = RowSequenceParallelLinear
         else:
             ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
             RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -670,7 +685,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )
 
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope:
+        if self.use_fused_rope and get_env_device() != "npu":
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -679,8 +694,17 @@
             self.use_fused_rope = False
 
         if config.sequence_parallel:
-            ColumnParallelLinear = ColumnSequenceParallelLinear
-            RowParallelLinear = RowSequenceParallelLinear
+            if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
+                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
+                    MC2ColumnSeqParallelLinear,
+                    MC2RowSeqParallelLinear,
+                )
+
+                ColumnParallelLinear = MC2ColumnSeqParallelLinear
+                RowParallelLinear = MC2RowSeqParallelLinear
+            else:
+                ColumnParallelLinear = ColumnSequenceParallelLinear
+                RowParallelLinear = RowSequenceParallelLinear
         else:
             ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
             RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -1526,9 +1550,12 @@ def forward(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]
         if self.config.use_flash_attention:
-            is_casual = is_casual_mask(attention_mask)
-            if is_casual and alibi is None:
-                attention_mask = None
+            if get_env_device() != "npu":
+                is_casual = is_casual_mask(attention_mask)
+                if is_casual and alibi is None:
+                    attention_mask = None
+            else:
+                attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
 
         # decoder layers
diff --git a/paddlenlp/transformers/mc2_seqence_parallel_linear.py b/paddlenlp/transformers/mc2_seqence_parallel_linear.py
new file mode 100644
index 000000000000..7d669833e690
--- /dev/null
+++ b/paddlenlp/transformers/mc2_seqence_parallel_linear.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import paddle
+
+try:
+    import paddle_custom_device
+except ImportError:
+    raise ImportError("Current device does not support MC2!")
+
+from paddle import distributed as dist
+from paddle.autograd import PyLayer
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+    ColumnSequenceParallelLinear,
+    RowSequenceParallelLinear,
+)
+
+__all_gather_recomputation__ = False
+if int(os.getenv("MC2_Recompute", 0)):
+    __all_gather_recomputation__ = True
+
+
+class MC2Column(PyLayer):
+    @staticmethod
+    def forward(ctx, input_, weight, group):
+        ctx.save_for_backward(input_, weight)
+
+        rank = dist.get_rank()
+        hcomm_info = group.process_group.get_comm_name(rank)
+
+        world_size = group.nranks
+        output, gather_out = paddle_custom_device.npu.fused_allgather_mm(
+            input_,
+            weight,
+            bias=None,
+            hcom=hcomm_info,
+            world_size=world_size,
+            gather_index=0,
+            gather_output=(not __all_gather_recomputation__),
+            comm_turn=0,
+        )
+
+        ctx.all_gather_output = gather_out
+        ctx.world_size = world_size
+        ctx.group = group
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensor()
+
+        if __all_gather_recomputation__:
+            dim_size = input_.shape
+            dim_size[0] = dim_size[0] * ctx.world_size
+            all_gather_output = paddle.empty(dim_size, dtype=input_.dtype)
+            all_gather_output.stop_gradient = True
+            all_gather_work = dist.stream.all_gather(all_gather_output, input_, group=ctx.group, sync_op=False)
+        else:
+            all_gather_output = ctx.all_gather_output
+
+        grad_input = paddle.matmul(grad_output, weight, transpose_y=True)
+        sub_grad_input = paddle.empty(input_.shape, dtype=input_.dtype)
+        reduce_scatter_work = dist.stream.reduce_scatter(sub_grad_input, grad_input, group=ctx.group, sync_op=False)
+
+        if __all_gather_recomputation__:
+            all_gather_work.wait()
+
+        grad_weight = paddle.matmul(all_gather_output, grad_output, transpose_x=True)
+        reduce_scatter_work.wait()
+
+        return sub_grad_input, grad_weight
+
+
+class MC2Row(PyLayer):
+    @staticmethod
+    def forward(ctx, input_, weight, group):
+        ctx.save_for_backward(input_, weight)
+
+        rank = dist.get_rank()
+        hcomm_info = group.process_group.get_comm_name(rank)
+        world_size = group.nranks
+
+        output = paddle_custom_device.npu.fused_mm_reduce_scatter(
+            input_,
+            weight,
+            bias=None,
+            hcom=hcomm_info,
+            world_size=world_size,
+            reduce_op="sum",
+            comm_turn=0,
+        )
+
+        ctx.hcomm_info = hcomm_info
+        ctx.world_size = world_size
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensor()
+        hcomm_info = ctx.hcomm_info
+        world_size = ctx.world_size
+
+        grad_input, all_gather_grad_output = paddle_custom_device.npu.fused_allgather_mm(
+            grad_output,
+            weight.t(),
+            bias=None,
+            hcom=hcomm_info,
+            world_size=world_size,
+            gather_index=0,
+            gather_output=True,
+            comm_turn=0,
+        )
+        grad_weight = paddle.matmul(input_, all_gather_grad_output, transpose_x=True)
+
+        return grad_input, grad_weight
+
+
+class MC2ColumnSeqParallelLinear(ColumnSequenceParallelLinear):
+    def forward(self, x):
+        output = MC2Column.apply(x, self.weight, self.model_parallel_group)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
+class MC2RowSeqParallelLinear(RowSequenceParallelLinear):
+    def forward(self, x):
+        output = MC2Row.apply(x, self.weight, self.model_parallel_group)
+        output = output + self.bias if self.bias is not None else output
+        return output
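
Reviewer note (not part of the patch): below is a minimal sketch of how the MC2 code path introduced above gets selected, assuming an NPU build of Paddle where paddle_custom_device is importable and the FLAGS_NPU_MC2 environment variable is exported. The helper pick_parallel_linears is hypothetical; it only mirrors the branch added in LlamaMLP.__init__ and LlamaAttention.__init__, and every imported name comes straight from the diff.

import os

from paddlenlp.transformers.llama.modeling import is_mc2_valid


def pick_parallel_linears():
    # Hypothetical helper mirroring the selection branch added in modeling.py.
    if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
        # NPU with FLAGS_NPU_MC2=1: use the fused allgather-matmul /
        # matmul-reduce-scatter layers from the new module.
        from paddlenlp.transformers.mc2_seqence_parallel_linear import (
            MC2ColumnSeqParallelLinear,
            MC2RowSeqParallelLinear,
        )

        return MC2ColumnSeqParallelLinear, MC2RowSeqParallelLinear

    # Otherwise fall back to the stock sequence-parallel linears shipped with Paddle.
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        ColumnSequenceParallelLinear,
        RowSequenceParallelLinear,
    )

    return ColumnSequenceParallelLinear, RowSequenceParallelLinear

Gating on both the device check and the environment flag keeps the fused kernels opt-in, so GPU runs never reach the paddle_custom_device import in the new module.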