diff --git a/Jenkinsfile b/Jenkinsfile
index 9932fa3b7777..0a3e4d331ddf 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -3335,7 +3335,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         model.optim.sched.min_lr=8e-5 \
         model.max_position_embeddings=128 \
         model.encoder_seq_length=128 \
-        model.activation=swiglu \
+        model.activation=fast-swiglu \
         model.bias_activation_fusion=False \
         model.hidden_dropout=0.0 \
         model.attention_dropout=0.0 \
@@ -3371,7 +3371,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         model.optim.sched.min_lr=8e-5 \
         model.max_position_embeddings=128 \
         model.encoder_seq_length=128 \
-        model.activation=swiglu \
+        model.activation=fast-swiglu \
         model.bias_activation_fusion=False \
         model.hidden_dropout=0.0 \
         model.attention_dropout=0.0 \
@@ -3648,7 +3648,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         model.decoder.num_layers=2 \
         model.decoder.hidden_size=64 \
         model.decoder.num_attention_heads=8 \
-        model.decoder.activation='swiglu' \
+        model.decoder.activation='fast-swiglu' \
         model.decoder.masked_softmax_fusion=False \
         model.decoder.bias_activation_fusion=False \
         model.decoder.activations_checkpoint_method='block' \
@@ -3690,7 +3690,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         model.decoder.num_layers=2 \
         model.decoder.hidden_size=64 \
         model.decoder.num_attention_heads=8 \
-        model.decoder.activation='swiglu' \
+        model.decoder.activation='fast-swiglu' \
         model.decoder.masked_softmax_fusion=False \
         model.decoder.bias_activation_fusion=False \
         model.decoder.activations_checkpoint_method='block' \
diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index 94bd4abcac57..3f1199fc15d8 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -71,7 +71,7 @@ model:
   post_process: True # add pooler
   persist_layer_norm: True # Use of persistent fused layer norm kernel.
   bias: True # Whether to use bias terms in all weight matrices.
-  activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu']
+  activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
   headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head.
   transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
   openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
diff --git a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml
index a1dd20c5c468..b623d08e4e8b 100644
--- a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml
@@ -21,7 +21,7 @@ bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropo
 bias: True # Whether to use bias terms in all weight matrices.
 normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
 arch: 'transformer' # Options: ['transformer', 'perceiver']
-activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
+activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
 headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head.
 transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
 hidden_steps: 32 # Number of latent vectors to use for pereceiver encoders
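Both config files now document the fast-* variants alongside the existing activations, and the Jenkinsfile hunks above exercise them in CI through dot-path command-line overrides (model.activation=fast-swiglu). As a minimal sketch of doing the same override programmatically with OmegaConf (the config path and the choice of fields to set here are illustrative, not part of this patch):

from omegaconf import OmegaConf

# Load the stock GPT pretraining config and select the new fused-projection
# SwiGLU variant; bias_activation_fusion is disabled here simply to mirror
# the CI commands above.
cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")
cfg.model.activation = "fast-swiglu"
cfg.model.bias_activation_fusion = False
print(OmegaConf.to_yaml(cfg.model))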
diff --git a/nemo/collections/nlp/modules/common/megatron/mlp.py b/nemo/collections/nlp/modules/common/megatron/mlp.py
index cf9dbc6286f4..f0a67a4ea348 100644
--- a/nemo/collections/nlp/modules/common/megatron/mlp.py
+++ b/nemo/collections/nlp/modules/common/megatron/mlp.py
@@ -83,18 +83,32 @@ def __init__(
         self.dropout = dropout
         self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_])

-        if activation not in ['gelu', 'geglu', 'reglu', 'swiglu', 'squared-relu']:
+        supported_activations = [
+            'gelu',
+            'geglu',
+            'reglu',
+            'swiglu',
+            'squared-relu',
+            'fast-geglu',
+            'fast-swiglu',
+            'fast-reglu',
+        ]
+
+        if activation not in supported_activations:
             raise ValueError(
-                f"Activation {activation} not supported. Only gelu, geglu, reglu, swiglu, squared-relu are supported."
+                f"Activation {activation} not supported. Supported activations are {supported_activations}"
             )

+        self.fast_glu_activation = activation in ['fast-geglu', 'fast-swiglu', 'fast-reglu']
         no_async_tensor_model_parallel_allreduce = (
             parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
         )
         # Project to 4h.
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             hidden_size,
-            ffn_hidden_size,  # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
+            ffn_hidden_size * 2
+            if self.fast_glu_activation
+            else ffn_hidden_size,  # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True,
@@ -121,7 +135,14 @@ def __init__(
             gradient_accumulation_fusion=gradient_accumulation_fusion,
         )

-        self.glu_activation_family = activation in ['geglu', 'reglu', 'swiglu']
+        self.glu_activation_family = activation in [
+            'geglu',
+            'reglu',
+            'swiglu',
+            'fast-geglu',
+            'fast-reglu',
+            'fast-swiglu',
+        ]
         bias_activation_fusion_unavailable = activation in ['reglu', 'swiglu']

         if bias_activation_fusion_unavailable and bias_activation_fusion:
@@ -144,13 +165,13 @@ def __init__(
         # Give openai_gelu precedence over other activations if set, for HF compatibility. Normally this is off and shouldn't affect regular model training.
         if openai_gelu:
             self.activation_func = openai_gelu_func
-        elif activation in ["gelu", "geglu"]:
+        elif activation in ["gelu", "geglu", "fast-geglu"]:
             self.activation_func = F.gelu
         elif onnx_safe:
             self.activation_func = erf_gelu
-        elif activation == "reglu":
+        elif activation in ["reglu", "fast-reglu"]:
             self.activation_func = F.relu
-        elif activation == "swiglu":
+        elif activation in ["swiglu", "fast-swiglu"]:
             # SiLU or sigmoid linear unit is the same as swish with beta = 1 (which is what https://arxiv.org/pdf/2002.05202.pdf uses.)
             self.activation_func = F.silu
         elif activation == 'squared-relu':
@@ -191,20 +212,22 @@ def forward(self, hidden_states):
         # [s, b, 4hp]
         intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

-        if self.glu_activation_family:
+        if self.fast_glu_activation:
+            intermediate_parallel, intermediate_parallel_2 = torch.chunk(intermediate_parallel, 2, dim=-1)
+            if bias_parallel is not None:
+                bias_parallel, bias_parallel_2 = torch.chunk(bias_parallel, 2, dim=-1)
+        elif self.glu_activation_family and not self.fast_glu_activation:
             intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

         if self.bias_activation_fusion:
             if self.activation == 'gelu':
                 intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
-            elif self.activation == 'geglu':
+            elif self.activation in ['geglu', 'fast-geglu']:
                 intermediate_parallel = fused_bias_geglu(
                     intermediate_parallel, bias_parallel, intermediate_parallel_2, bias_parallel_2
                 )

-        elif self.activation in ['reglu', 'swiglu'] or (
-            self.glu_activation_family and not self.bias_activation_fusion
-        ):
+        elif self.glu_activation_family and not self.bias_activation_fusion:
             if bias_parallel is not None:
                 intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) * (
                     intermediate_parallel_2 + bias_parallel_2
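For readers of the mlp.py change: with a fast-* activation the second projection (dense_h_to_4h_2) is not used in forward(); instead dense_h_to_4h is built twice as wide (ffn_hidden_size * 2) and its output is split with torch.chunk into a gated half and a linear half. A self-contained PyTorch sketch of that equivalence, with illustrative names and sizes (plain nn.Linear without bias standing in for ColumnParallelLinear, not NeMo code):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, ffn = 64, 256
torch.manual_seed(0)

# 'swiglu': two independent projections; SiLU gates the first, the second stays linear.
w1 = nn.Linear(hidden, ffn, bias=False)  # analogue of dense_h_to_4h
w2 = nn.Linear(hidden, ffn, bias=False)  # analogue of dense_h_to_4h_2

# 'fast-swiglu': one doubled projection whose weight stacks the two above.
w_fused = nn.Linear(hidden, 2 * ffn, bias=False)
with torch.no_grad():
    w_fused.weight.copy_(torch.cat([w1.weight, w2.weight], dim=0))

x = torch.randn(8, hidden)

two_gemm = F.silu(w1(x)) * w2(x)

gate, lin = torch.chunk(w_fused(x), 2, dim=-1)  # mirrors the forward() change above
one_gemm = F.silu(gate) * lin

torch.testing.assert_close(two_gemm, one_gemm)  # same result, one GEMM instead of two

This is also why the patch widens dense_h_to_4h to ffn_hidden_size * 2 when fast_glu_activation is set: a single GEMM produces both halves at once.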