[LLM] Optimize llm/GPT3 performance (PaddlePaddle#8172)
MarioLulab authored Apr 11, 2024
1 parent a6d3a28 commit 2900f78
Showing 4 changed files with 66 additions and 18 deletions.
16 changes: 16 additions & 0 deletions llm/run_pretrain.py
@@ -75,6 +75,7 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)

# NOTE(gongenlei): new add autotuner_benchmark
autotuner_benchmark: bool = field(
default=False,
@@ -154,6 +155,18 @@ class ModelArguments:
default=False,
metadata={"help": "llama or other model, use_fused_rms_norm"},
)
use_fast_layer_norm: bool = field(
default=False,
metadata={"help": "GPT3 model, use fast layernorm"},
)
use_fused_linear: bool = field(
default=False,
metadata={"help": "GPT3 model, use fused linear layer"},
)
use_fused_dropout_add: bool = field(
default=False,
metadata={"help": "GPT3 model, use fused `dropout + residual add` op"},
)
fuse_attention_qkv: bool = field(
default=False,
metadata={"help": "whether to fuse attention qkv"},
@@ -440,6 +453,9 @@ def main():

config.use_flash_attention = model_args.use_flash_attention
config.use_fused_rms_norm = model_args.use_fused_rms_norm
config.use_fast_layer_norm = model_args.use_fast_layer_norm
config.use_fused_linear = model_args.use_fused_linear
config.use_fused_dropout_add = model_args.use_fused_dropout_add
config.fuse_attention_qkv = model_args.fuse_attention_qkv
config.fuse_attention_ffn = model_args.fuse_attention_ffn
config.recompute_granularity = model_args.recompute_granularity
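
For context, a minimal sketch of what these assignments amount to: a GPTConfig with the three new switches turned on. This is illustrative usage, not part of the commit; in a real run the values are parsed into ModelArguments from the launch command (roughly --use_fast_layer_norm true --use_fused_linear true --use_fused_dropout_add true) and copied onto the config as shown above.

# Illustrative sketch only: setting the new switches directly on a GPTConfig,
# mirroring the `config.<flag> = model_args.<flag>` assignments in main().
from paddlenlp.transformers.gpt.configuration import GPTConfig

config = GPTConfig(
    use_fast_layer_norm=True,    # GPT3: route LayerNorm through the fast_ln kernel
    use_fused_linear=True,       # GPT3: pass fuse_matmul_bias=True to the linear layers
    use_fused_dropout_add=True,  # GPT3: use the fused `dropout + residual add` op
)
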
6 changes: 4 additions & 2 deletions paddlenlp/transformers/gpt/configuration.py
@@ -257,7 +257,8 @@ def __init__(
ignore_index: int = 0,
use_flash_attention: bool = False,
use_fused_dropout_add: bool = False,
fused_linear: bool = False,
use_fast_layer_norm: bool = False,
use_fused_linear: bool = False,
fuse_attention_qkv: bool = False,
fuse_attention_ffn: bool = False,
fused_softmax_with_triangular: bool = False,
@@ -298,7 +299,8 @@ def __init__(
self.tensor_parallel_output = tensor_parallel_output
self.output_attentions = output_attentions
self.ignore_index = ignore_index
self.fused_linear = fused_linear
self.use_fast_layer_norm = use_fast_layer_norm
self.use_fused_linear = use_fused_linear
self.use_fused_dropout_add = use_fused_dropout_add
self.fused_softmax_with_triangular = fused_softmax_with_triangular
self.virtual_pp_degree = virtual_pp_degree
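
A quick self-check of the renamed and added attributes, assuming (as is typical for PaddleNLP configs) that GPTConfig can be instantiated with all-default arguments:

from paddlenlp.transformers.gpt.configuration import GPTConfig

cfg = GPTConfig()
assert cfg.use_fused_linear is False       # replaces the former `fused_linear` attribute
assert cfg.use_fast_layer_norm is False    # new switch for the fast_ln LayerNorm path
assert cfg.use_fused_dropout_add is False  # pre-existing config switch, kept alongside the new ones
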
52 changes: 42 additions & 10 deletions paddlenlp/transformers/gpt/modeling.py
@@ -37,6 +37,7 @@
mark_as_sequence_parallel_parameter,
)
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from paddle.utils import try_import

from ...utils.converter import StateDictNameMapping
from ...utils.log import logger
@@ -59,6 +60,9 @@
except:
FusedDropoutAdd = None

OriginLayerNorm = paddle.nn.LayerNorm


__all__ = [
"GPTModel",
"GPTPretrainedModel",
@@ -70,6 +74,7 @@
"GPTForCausalLM",
"GPTEmbeddings",
"GPTDecoderLayer",
"GPTLayerNorm",
]


@@ -119,6 +124,11 @@ def seed_guard_context(name=None):
return contextlib.nullcontext()


def fast_layer_norm(input, weight, bias, eps):
fast_ln_lib = try_import("fast_ln")
return fast_ln_lib.fast_ln(input, weight, bias, eps)[0]


def _make_causal_mask(input_ids_shape, past_key_values_length):
"""
Make causal mask used for self-attention
@@ -149,6 +159,11 @@ def _expand_2d_mask(mask, dtype, tgt_length):
return expanded_mask


def _check_normalized_shape(normalized_shape):
if isinstance(normalized_shape, (list, tuple)):
assert len(normalized_shape) == 1


class MultiHeadAttention(nn.Layer):
"""
Attention maps queries and a set of key-value pairs to outputs, and
@@ -196,39 +211,39 @@ def __init__(
3 * config.hidden_size,
has_bias=True,
gather_output=False,
fuse_matmul_bias=config.fused_linear,
fuse_matmul_bias=config.use_fused_linear,
)
else:
self.q_proj = ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
has_bias=True,
gather_output=False,
fuse_matmul_bias=config.fused_linear,
fuse_matmul_bias=config.use_fused_linear,
)

self.k_proj = ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
has_bias=True,
gather_output=False,
fuse_matmul_bias=config.fused_linear,
fuse_matmul_bias=config.use_fused_linear,
)

self.v_proj = ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
has_bias=True,
gather_output=False,
fuse_matmul_bias=config.fused_linear,
fuse_matmul_bias=config.use_fused_linear,
)

self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
has_bias=True,
input_is_parallel=True,
fuse_matmul_bias=config.fused_linear,
fuse_matmul_bias=config.use_fused_linear,
)
else:
if self.config.fuse_attention_qkv:
@@ -421,7 +436,7 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):

self.config = config
self.layers = decoder_layers
self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)

if config.sequence_parallel:
mark_as_sequence_parallel_parameter(self.norm.weight)
@@ -566,21 +581,23 @@ def __init__(self, config: GPTConfig):
config.intermediate_size,
gather_output=False,
has_bias=True,
fuse_matmul_bias=self.config.fused_linear,
fuse_matmul_bias=self.config.use_fused_linear,
)

self.linear2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
has_bias=True,
fuse_matmul_bias=self.config.fused_linear,
fuse_matmul_bias=self.config.use_fused_linear,
)
else:
self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

self.norm1 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
self.norm2 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)

if config.sequence_parallel:
mark_as_sequence_parallel_parameter(self.norm1.weight)
mark_as_sequence_parallel_parameter(self.norm1.bias)
@@ -741,6 +758,21 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
return embeddings


class GPTLayerNorm(OriginLayerNorm):
def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None):
super().__init__(
normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=bias_attr
)

self.config = config
_check_normalized_shape(self._normalized_shape)

def forward(self, input):
if self.config.use_fast_layer_norm:
return fast_layer_norm(input, self.weight, self.bias, self._epsilon)
return super().forward(input)


class GPTPretrainedModel(PretrainedModel):
"""
An abstract class for pretrained GPT models. It provides GPT related
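
For reference, a minimal usage sketch of the GPTLayerNorm class added above. Assumptions: GPTConfig accepts hidden_size as a keyword argument with defaults for everything else, and with use_fast_layer_norm left False the layer simply falls back to paddle.nn.LayerNorm; the fast path additionally requires the external fast_ln op to be importable, since try_import("fast_ln") raises otherwise.

import paddle

from paddlenlp.transformers.gpt.configuration import GPTConfig
from paddlenlp.transformers.gpt.modeling import GPTLayerNorm

config = GPTConfig(hidden_size=1024, use_fast_layer_norm=False)  # False -> plain paddle.nn.LayerNorm
norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)

x = paddle.randn([2, 8, config.hidden_size], dtype="float32")
y = norm(x)  # with use_fast_layer_norm=True this would return fast_ln.fast_ln(...)[0] instead
print(y.shape)  # [2, 8, 1024]
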
10 changes: 4 additions & 6 deletions paddlenlp/transformers/gpt/modeling_pp.py
@@ -13,7 +13,6 @@
# limitations under the License.
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
LayerDesc,
PipelineLayer,
@@ -30,6 +29,7 @@
GPTConfig,
GPTDecoderLayer,
GPTEmbeddings,
GPTLayerNorm,
GPTLMHead,
GPTPretrainedModel,
GPTPretrainingCriterion,
@@ -103,15 +103,13 @@ def forward(self, args):
embeddings = super().forward(input_ids=input_ids, position_ids=position_ids)

batch_size, seq_length = input_ids.shape
causal_mask = self.bias[:, :, 0:seq_length, :seq_length]
if attention_mask is not None:
if attention_mask.dtype != paddle.int64:
attention_mask = paddle.cast(attention_mask, dtype=paddle.int64)
if len(attention_mask.shape) == 2:
attention_mask = attention_mask[:, None, None, :]
causal_mask = self.bias[:, :, 0:seq_length, :seq_length]
attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4
else:
attention_mask = (1.0 - causal_mask) * -1e4

return return_args(embeddings, attention_mask, position_ids)

@@ -127,9 +125,9 @@ def forward(self, args):
return return_args(hidden_states, attention_mask, position_ids)


class LayerNormPipe(nn.LayerNorm):
class LayerNormPipe(GPTLayerNorm):
def __init__(self, config):
super(LayerNormPipe, self).__init__(config.hidden_size, epsilon=1e-05)
super(LayerNormPipe, self).__init__(config, config.hidden_size, epsilon=1e-05)
if config.sequence_parallel:
mark_as_sequence_parallel_parameter(self.weight)
mark_as_sequence_parallel_parameter(self.bias)
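
Finally, a small worked sketch of the additive-mask arithmetic in the EmbeddingPipe.forward hunk above: positions a token may attend to end up as 0, while future and padding positions become a large negative bias (-1e4) intended to be added to the attention logits downstream. The paddle.tril call below stands in for the sliced self.bias buffer used in the real code, and the padding pattern is made up for illustration.

import paddle

seq_length = 4
# stand-in for self.bias[:, :, 0:seq_length, :seq_length]: lower-triangular causal mask
causal_mask = paddle.tril(paddle.ones([1, 1, seq_length, seq_length], dtype="int64"))
# 2-D padding mask (here: last token is padding), expanded to [batch, 1, 1, seq] as in the hunk
attention_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")[:, None, None, :]

additive_mask = (1.0 - (attention_mask & causal_mask)) * -1e4
print(additive_mask)  # 0 where attention is allowed, -1e4 where it is masked
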
