diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index 8ac94415b509..d0df32321e18 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -75,6 +75,7 @@ class PreTrainingArguments(TrainingArguments):
             "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
         },
     )
+    # NOTE(gongenlei): new add autotuner_benchmark
     autotuner_benchmark: bool = field(
         default=False,
@@ -154,6 +155,18 @@ class ModelArguments:
         default=False,
         metadata={"help": "llama or other model, use_fused_rms_norm"},
     )
+    use_fast_layer_norm: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fast layernorm"},
+    )
+    use_fused_linear: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fused linear layer"},
+    )
+    use_fused_dropout_add: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fused `dropout + residual add` op"},
+    )
     fuse_attention_qkv: bool = field(
         default=False,
         metadata={"help": "whether to fuse attention qkv"},
@@ -440,6 +453,9 @@ def main():
     config.use_flash_attention = model_args.use_flash_attention
     config.use_fused_rms_norm = model_args.use_fused_rms_norm
+    config.use_fast_layer_norm = model_args.use_fast_layer_norm
+    config.use_fused_linear = model_args.use_fused_linear
+    config.use_fused_dropout_add = model_args.use_fused_dropout_add
     config.fuse_attention_qkv = model_args.fuse_attention_qkv
     config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
diff --git a/paddlenlp/transformers/gpt/configuration.py b/paddlenlp/transformers/gpt/configuration.py
index 1c0645bd8a14..250a291d43ff 100644
--- a/paddlenlp/transformers/gpt/configuration.py
+++ b/paddlenlp/transformers/gpt/configuration.py
@@ -257,7 +257,8 @@ def __init__(
         ignore_index: int = 0,
         use_flash_attention: bool = False,
         use_fused_dropout_add: bool = False,
-        fused_linear: bool = False,
+        use_fast_layer_norm: bool = False,
+        use_fused_linear: bool = False,
         fuse_attention_qkv: bool = False,
         fuse_attention_ffn: bool = False,
         fused_softmax_with_triangular: bool = False,
@@ -298,7 +299,8 @@ def __init__(
         self.tensor_parallel_output = tensor_parallel_output
         self.output_attentions = output_attentions
         self.ignore_index = ignore_index
-        self.fused_linear = fused_linear
+        self.use_fast_layer_norm = use_fast_layer_norm
+        self.use_fused_linear = use_fused_linear
         self.use_fused_dropout_add = use_fused_dropout_add
         self.fused_softmax_with_triangular = fused_softmax_with_triangular
         self.virtual_pp_degree = virtual_pp_degree
diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py
index 167879c46256..c83466e33041 100644
--- a/paddlenlp/transformers/gpt/modeling.py
+++ b/paddlenlp/transformers/gpt/modeling.py
@@ -37,6 +37,7 @@
     mark_as_sequence_parallel_parameter,
 )
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from paddle.utils import try_import

 from ...utils.converter import StateDictNameMapping
 from ...utils.log import logger
@@ -59,6 +60,9 @@
 except:
     FusedDropoutAdd = None

+OriginLayerNorm = paddle.nn.LayerNorm
+
+
 __all__ = [
     "GPTModel",
     "GPTPretrainedModel",
@@ -70,6 +74,7 @@
     "GPTForCausalLM",
     "GPTEmbeddings",
     "GPTDecoderLayer",
+    "GPTLayerNorm",
 ]


@@ -119,6 +124,11 @@ def seed_guard_context(name=None):
         return contextlib.nullcontext()


+def fast_layer_norm(input, weight, bias, eps):
+    fast_ln_lib = try_import("fast_ln")
+    return fast_ln_lib.fast_ln(input, weight, bias, eps)[0]
+
+
 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     Make causal mask used for self-attention
@@ -149,6 +159,11 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     return expanded_mask


+def _check_normalized_shape(normalized_shape):
+    if isinstance(normalized_shape, (list, tuple)):
+        assert len(normalized_shape) == 1
+
+
 class MultiHeadAttention(nn.Layer):
     """
     Attention mapps queries and a set of key-value pairs to outputs, and
@@ -196,7 +211,7 @@ def __init__(
                     3 * config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
@@ -204,7 +219,7 @@ def __init__(
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )

                 self.k_proj = ColumnParallelLinear(
@@ -212,7 +227,7 @@ def __init__(
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )

                 self.v_proj = ColumnParallelLinear(
@@ -220,7 +235,7 @@ def __init__(
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )

             self.out_proj = RowParallelLinear(
@@ -228,7 +243,7 @@ def __init__(
                 config.hidden_size,
                 has_bias=True,
                 input_is_parallel=True,
-                fuse_matmul_bias=config.fused_linear,
+                fuse_matmul_bias=config.use_fused_linear,
             )
         else:
             if self.config.fuse_attention_qkv:
@@ -421,7 +436,7 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):
         self.config = config
         self.layers = decoder_layers
-        self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
+        self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.norm.weight)
@@ -566,21 +581,23 @@ def __init__(self, config: GPTConfig):
                 config.intermediate_size,
                 gather_output=False,
                 has_bias=True,
-                fuse_matmul_bias=self.config.fused_linear,
+                fuse_matmul_bias=self.config.use_fused_linear,
             )
+
             self.linear2 = RowParallelLinear(
                 config.intermediate_size,
                 config.hidden_size,
                 input_is_parallel=True,
                 has_bias=True,
-                fuse_matmul_bias=self.config.fused_linear,
+                fuse_matmul_bias=self.config.use_fused_linear,
             )
         else:
             self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
             self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

-        self.norm1 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
-        self.norm2 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
+        self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
+        self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
+
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.norm1.weight)
             mark_as_sequence_parallel_parameter(self.norm1.bias)
@@ -741,6 +758,21 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         return embeddings


+class GPTLayerNorm(OriginLayerNorm):
+    def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None):
+        super().__init__(
+            normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=bias_attr
+        )
+
+        self.config = config
+        _check_normalized_shape(self._normalized_shape)
+
+    def forward(self, input):
+        if self.config.use_fast_layer_norm:
+            return fast_layer_norm(input, self.weight, self.bias, self._epsilon)
+        return super().forward(input)
+
+
 class GPTPretrainedModel(PretrainedModel):
     """
     An abstract class for pretrained GPT models. It provides GPT related
diff --git a/paddlenlp/transformers/gpt/modeling_pp.py b/paddlenlp/transformers/gpt/modeling_pp.py
index fb4946febc46..3ec6b004edee 100644
--- a/paddlenlp/transformers/gpt/modeling_pp.py
+++ b/paddlenlp/transformers/gpt/modeling_pp.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import paddle
 import paddle.distributed.fleet as fleet
-import paddle.nn as nn
 from paddle.distributed.fleet.meta_parallel import (
     LayerDesc,
     PipelineLayer,
@@ -30,6 +29,7 @@
     GPTConfig,
     GPTDecoderLayer,
     GPTEmbeddings,
+    GPTLayerNorm,
     GPTLMHead,
     GPTPretrainedModel,
     GPTPretrainingCriterion,
@@ -103,15 +103,13 @@ def forward(self, args):
         embeddings = super().forward(input_ids=input_ids, position_ids=position_ids)

         batch_size, seq_length = input_ids.shape
-        causal_mask = self.bias[:, :, 0:seq_length, :seq_length]
         if attention_mask is not None:
             if attention_mask.dtype != paddle.int64:
                 attention_mask = paddle.cast(attention_mask, dtype=paddle.int64)
             if len(attention_mask.shape) == 2:
                 attention_mask = attention_mask[:, None, None, :]
+            causal_mask = self.bias[:, :, 0:seq_length, :seq_length]
             attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4
-        else:
-            attention_mask = (1.0 - causal_mask) * -1e4

         return return_args(embeddings, attention_mask, position_ids)
@@ -127,9 +125,9 @@ def forward(self, args):
         return return_args(hidden_states, attention_mask, position_ids)


-class LayerNormPipe(nn.LayerNorm):
+class LayerNormPipe(GPTLayerNorm):
     def __init__(self, config):
-        super(LayerNormPipe, self).__init__(config.hidden_size, epsilon=1e-05)
+        super(LayerNormPipe, self).__init__(config, config.hidden_size, epsilon=1e-05)
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.weight)
             mark_as_sequence_parallel_parameter(self.bias)
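
For reference (not part of the patch), the sketch below shows how the renamed and newly added GPT switches would be set once this change is applied. Only the flag names `use_fast_layer_norm`, `use_fused_linear`, and `use_fused_dropout_add` come from the diff; the toy model sizes and the `paddlenlp.transformers` import path are illustrative assumptions, and `use_fast_layer_norm=True` additionally requires the optional `fast_ln` extension that `try_import("fast_ln")` looks for.

```python
# Minimal usage sketch, assuming this patch is applied to PaddleNLP.
# Toy sizes only; flip the flags to exercise the fused paths
# (use_fast_layer_norm=True needs the optional `fast_ln` extension installed).
import paddle
from paddlenlp.transformers import GPTConfig, GPTModel

config = GPTConfig(
    hidden_size=256,              # illustrative, not a released GPT-3 setting
    num_hidden_layers=2,
    num_attention_heads=8,
    use_fast_layer_norm=False,    # GPTLayerNorm dispatches to fast_ln.fast_ln when True
    use_fused_linear=False,       # forwarded as fuse_matmul_bias to the parallel linear layers
    use_fused_dropout_add=False,  # fused `dropout + residual add` op
)

model = GPTModel(config)
input_ids = paddle.randint(0, config.vocab_size, shape=[2, 16])
output = model(input_ids)  # last hidden states, roughly [batch, seq_len, hidden_size]
```

Because all three flags default to `False`, existing configs keep the plain LayerNorm and unfused linear paths unless a user opts in explicitly.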