add mc2 & finetune fused
NINGBENZHE committed Apr 3, 2024
1 parent 2273ee7 commit a356945
Showing 5 changed files with 264 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .copyright.hook
@@ -71,7 +71,7 @@ RE_SHEBANG = re.compile(r"^[ \t\v]*#[ \t]?\!")
def _check_copyright(path):
head=[]
try:
with open(path) as f:
with open(path, encoding="utf-8") as f:
head = [next(f) for x in range(4)]
except StopIteration:
pass
62 changes: 62 additions & 0 deletions llm/argument.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional

from paddlenlp.trainer import TrainingArguments
from paddlenlp.trainer.trainer_utils import IntervalStrategy
@@ -48,6 +49,9 @@ class DataArgument:
dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
pad_to_multiple_of: int = field(
default=None, metadata={"help": "If set, pad the sequence to a multiple of the provided value."}
)
src_length: int = field(default=1024, metadata={"help": "The maximum length of source(context) tokens."})
max_length: int = field(
default=2048,
@@ -102,6 +106,64 @@ class ModelArgument:
default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
)
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
tokenizer_name_or_path: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_fused_rms_norm: bool = field(
default=False,
metadata={"help": "Whether to use the fused RMSNorm kernel (for llama and similar models)."},
)
fuse_attention_qkv: bool = field(
default=False,
metadata={"help": "whether to fuse attention qkv"},
)
fuse_attention_ffn: bool = field(
default=False,
metadata={"help": "Whether to fuse the up and gate projections in the MLP block."},
)
recompute_granularity: str = field(
default="full",
metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"},
)
virtual_pp_degree: int = field(
default=1,
metadata={"help": "Virtual pipeline parallel degree."},
)
hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})

continue_training: bool = field(
default=False,
metadata={
"help": "Continue training from existing paddlenlp model weights. Defaults to False, in which case the model trains from scratch. If set to True, model_name_or_path must refer to an existing paddlenlp model."
},
)
sequence_parallel: bool = field(
default=False,
metadata={"help": "whether to use sequence parallel"},
)
fuse_sequence_parallel_allreduce: bool = field(
default=False,
metadata={"help": "whether to use fuse sequence parallel allreduce"},
)
use_fused_rope: Optional[bool] = field(
default=False,
metadata={"help": "Enable rope fusion or not."},
)
no_recompute_layers: Optional[List[int]] = field(
default=None,
metadata={"help": "Specify the full transformer layers that should not be recomputed."},
)
pp_recompute_interval: int = field(
default=1,
metadata={
"help": "The interval, in number of layers, at which recomputation occurs. A value of 0 indicates no recomputation. Default is 1."
},
)
recompute_use_reentrant: bool = field(
default=False,
metadata={"help": "Whether to use the reentrant implementation of recompute."},
)
weight_quantize_algo: str = field(
default=None,
metadata={
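The new ModelArgument fields become command-line flags through the dataclass-based argument parser the llm scripts already use. A minimal sketch of how the boolean fusion switches parse, using a trimmed, hypothetical stand-in dataclass (PdArgumentParser is assumed to behave like the HF-style dataclass parser it mirrors; everything else here is illustrative):

# Hypothetical stand-in for a few of the ModelArgument fields above.
from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser


@dataclass
class _FusionArgs:
    use_fused_rms_norm: bool = field(default=False, metadata={"help": "Use fused RMSNorm."})
    fuse_attention_qkv: bool = field(default=False, metadata={"help": "Fuse the QKV projection."})
    sequence_parallel: bool = field(default=False, metadata={"help": "Enable sequence parallelism."})


(args,) = PdArgumentParser(_FusionArgs).parse_args_into_dataclasses(
    ["--use_fused_rms_norm", "true", "--fuse_attention_qkv", "true"]
)
print(args.use_fused_rms_norm, args.fuse_attention_qkv, args.sequence_parallel)  # expected: True True False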
25 changes: 25 additions & 0 deletions llm/finetune_generation.py
@@ -154,6 +154,30 @@ def main():
if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention

model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
model_config.recompute_granularity = model_args.recompute_granularity
model_config.virtual_pp_degree = model_args.virtual_pp_degree
model_config.sequence_parallel = model_args.sequence_parallel
model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
model_config.use_fused_rope = model_args.use_fused_rope

model_config.no_recompute_layers = model_args.no_recompute_layers
model_config.pp_recompute_interval = model_args.pp_recompute_interval
model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
model_config.use_recompute = training_args.recompute

model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
model_config.tensor_parallel_rank = training_args.tensor_parallel_rank

# Config for model using dropout, such as GPT.
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = True
model_config.seq_length = data_args.max_length
if not training_args.autotuner_benchmark:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
@@ -494,6 +518,7 @@ def compute_metrics_do_generation(eval_preds):
padding=padding,
max_label_length=max_length,
return_tensors="np",
pad_to_multiple_of=data_args.pad_to_multiple_of,
),
do_generation=data_args.eval_with_do_generation,
callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
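The first hunk above mirrors the parsed arguments onto the model config before AutoModelForCausalLM.from_pretrained builds the network; the second threads data_args.pad_to_multiple_of into the data collator so each padded batch length is rounded up to a multiple of that value (e.g. 8 or 16, which keeps fused kernels on aligned shapes). A hypothetical helper capturing the propagation pattern (the function name and field list are illustrative, not part of this PR):

# Hypothetical helper: copy fusion/parallelism switches from the parsed
# arguments onto the model config, mirroring the assignments above.
_FUSION_FIELDS = (
    "use_fused_rms_norm",
    "fuse_attention_qkv",
    "fuse_attention_ffn",
    "sequence_parallel",
    "fuse_sequence_parallel_allreduce",
    "use_fused_rope",
)


def apply_fusion_args(model_config, model_args, training_args):
    for name in _FUSION_FIELDS:
        setattr(model_config, name, getattr(model_args, name))
    # Recompute and parallel degrees come from the trainer arguments instead.
    model_config.use_recompute = training_args.recompute
    model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
    model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
    return model_config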
47 changes: 37 additions & 10 deletions paddlenlp/transformers/llama/modeling.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import math
import os
import warnings
from functools import partial
from typing import Optional, Tuple
@@ -75,8 +76,6 @@ def swiglu(x, y=None):

try:
if get_env_device() == "npu":
import os

from paddle.base import core

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -94,6 +93,13 @@ def swiglu(x, y=None):
]


def is_mc2_valid():
current_device = get_env_device()
if current_device == "npu":
return True
return False


def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -565,8 +571,17 @@ def __init__(self, config):
self.fuse_attention_ffn = config.fuse_attention_ffn

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
from paddlenlp.transformers.mc2_seqence_parallel_linear import (
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
)

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -670,7 +685,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope:
if self.use_fused_rope and get_env_device() != "npu":
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -679,8 +694,17 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
self.use_fused_rope = False

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
from paddlenlp.transformers.mc2_seqence_parallel_linear import (
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
)

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
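The same selection appears in both the MLP and the attention constructors: the MC2 fused linears are picked only when the process runs on an NPU and FLAGS_NPU_MC2 is set. A standalone sketch of that gate (it returns class names rather than the real classes so it runs without NPU or distributed dependencies):

import os


def pick_sequence_parallel_linears(device):
    # Mirrors the gate above: NPU device plus FLAGS_NPU_MC2=1 selects MC2.
    use_mc2 = device == "npu" and bool(int(os.getenv("FLAGS_NPU_MC2", "0")))
    if use_mc2:
        return "MC2ColumnSeqParallelLinear", "MC2RowSeqParallelLinear"
    return "ColumnSequenceParallelLinear", "RowSequenceParallelLinear"


os.environ["FLAGS_NPU_MC2"] = "1"
print(pick_sequence_parallel_linears("npu"))  # MC2 pair
print(pick_sequence_parallel_linears("gpu"))  # stock sequence-parallel pair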
@@ -1526,9 +1550,12 @@ def forward(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
if self.config.use_flash_attention:
is_casual = is_casual_mask(attention_mask)
if is_casual and alibi is None:
attention_mask = None
if get_env_device() != "npu":
is_casual = is_casual_mask(attention_mask)
if is_casual and alibi is None:
attention_mask = None
else:
attention_mask = attention_mask.astype("bool")
hidden_states = inputs_embeds

# decoder layers
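With flash attention the forward pass now branches on the device: off NPU, a purely causal mask with no alibi bias is dropped so the kernel applies causal masking internally, while on NPU the dense mask is kept but cast to bool. A minimal sketch of that branch (is_causal stands in for the is_casual_mask(attention_mask) check and device for get_env_device()):

def prepare_flash_attn_mask(attention_mask, device, is_causal, alibi=None):
    if device != "npu":
        # A purely causal mask with no alibi bias can be dropped entirely;
        # the flash-attention kernel then applies causal masking itself.
        return None if (is_causal and alibi is None) else attention_mask
    # The NPU kernel expects an explicit boolean mask instead.
    return attention_mask.astype("bool")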
139 changes: 139 additions & 0 deletions paddlenlp/transformers/mc2_seqence_parallel_linear.py
@@ -0,0 +1,139 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle_custom_device
from paddle import distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)

__all_gather_recomputation__ = False
if int(os.getenv("MC2_Recompute", 0)):
__all_gather_recomputation__ = True


class MC2Column(PyLayer):
@staticmethod
def forward(ctx, input_, weight, group):
ctx.save_for_backward(input_, weight)

rank = dist.get_rank()
hcomm_info = group.process_group.get_comm_name(rank)

world_size = group.nranks
output, gather_out = paddle_custom_device.npu.fused_allgather_mm(
input_,
weight,
bias=None,
hcom=hcomm_info,
world_size=world_size,
gather_index=0,
gather_output=(not __all_gather_recomputation__),
comm_turn=0,
)

ctx.all_gather_output = gather_out
ctx.world_size = world_size
ctx.group = group
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensor()

if __all_gather_recomputation__:
dim_size = input_.shape
dim_size[0] = dim_size[0] * ctx.world_size
all_gather_output = paddle.empty(dim_size, dtype=input_.dtype)
all_gather_output.stop_gradient = True
all_gather_work = dist.stream.reduce_scatter(
all_gather_output, input_.contiguous(), group=ctx.group, sync_op=False
)
else:
all_gather_output = ctx.all_gather_output

grad_input = paddle.matmul(grad_output, weight, transpose_y=True)
sub_grad_input = paddle.empty(input_.shape, dtype=input_.dtype)
reduce_scatter_work = dist.stream.reduce_scatter(sub_grad_input, grad_input, group=ctx.group, sync_op=False)

if __all_gather_recomputation__:
all_gather_work.wait()

grad_weight = paddle.matmul(all_gather_output, grad_output, transpose_x=True)
reduce_scatter_work.wait()

return sub_grad_input, grad_weight


class MC2Row(PyLayer):
@staticmethod
def forward(ctx, input_, weight, group):
ctx.save_for_backward(input_, weight)

rank = dist.get_rank()
hcomm_info = group.process_group.get_comm_name(rank)
world_size = group.nranks

output = paddle_custom_device.npu.fused_mm_reduce_scatter(
input_,
weight,
bias=None,
hcom=hcomm_info,
world_size=world_size,
reduce_op="sum",
comm_turn=0,
)

ctx.hcomm_info = hcomm_info
ctx.world_size = world_size
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensor()
hcomm_info = ctx.hcomm_info
world_size = ctx.world_size

grad_input, all_gather_grad_output = paddle_custom_device.npu.fused_allgather_mm(
grad_output,
weight.t(),
bias=None,
hcom=hcomm_info,
world_size=world_size,
gather_index=0,
gather_output=True,
comm_turn=0,
)
grad_weight = paddle.matmul(input_, all_gather_grad_output, transpose_x=True)

return grad_input, grad_weight


class MC2ColumnSeqParallelLinear(ColumnSequenceParallelLinear):
def forward(self, x):
output = MC2Column.apply(x, self.weight, self.model_parallel_group)
output = output + self.bias if self.bias is not None else output
return output


class MC2RowSeqParallelLinear(RowSequenceParallelLinear):
def forward(self, x):
output = MC2Row.apply(x, self.weight, self.model_parallel_group)
output = output + self.bias if self.bias is not None else output
return output
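Conceptually, the two fused NPU kernels collapse a communication step and a matmul into one call: fused_allgather_mm gathers the sequence-sharded input across ranks and multiplies it by this rank's column shard of the weight, while fused_mm_reduce_scatter multiplies first, sums the partial results across ranks, and hands each rank back one sequence shard. A single-process sketch of that arithmetic in plain paddle (illustrative shapes, no NPU or distributed setup; this shows the math, not the kernel implementation):

import paddle

world_size = 2
seq_shard, hidden, ffn_shard = 4, 8, 16

# Column-parallel + sequence-parallel forward (cf. fused_allgather_mm):
# every rank holds a sequence shard of x and a column shard of the weight.
x_shards = [paddle.randn([seq_shard, hidden]) for _ in range(world_size)]
w_col = paddle.randn([hidden, ffn_shard])
gathered = paddle.concat(x_shards, axis=0)        # the "all-gather" over the sequence dim
col_out = paddle.matmul(gathered, w_col)          # -> [seq_shard * world_size, ffn_shard]

# Row-parallel + sequence-parallel forward (cf. fused_mm_reduce_scatter):
# every rank multiplies with its row shard, the partials are summed ("reduce")
# and each rank keeps one sequence shard of the sum ("scatter").
h_shards = [paddle.randn([seq_shard * world_size, ffn_shard]) for _ in range(world_size)]
w_rows = [paddle.randn([ffn_shard, hidden]) for _ in range(world_size)]
partials = [paddle.matmul(h, w) for h, w in zip(h_shards, w_rows)]
summed = paddle.add_n(partials)
row_out_rank0 = paddle.split(summed, world_size, axis=0)[0]

print(col_out.shape, row_out_rank0.shape)  # [8, 16] and [4, 8]

The MC2_Recompute environment flag read at import time above controls whether the gathered activation from the column forward is kept for backward (gather_output=True) or rebuilt during the backward pass instead, trading recomputation for memory.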
