Fast glu activations (#6058)
* fast glu activations

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Clean up activation list

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
MaximumEntropy and pre-commit-ci[bot] committed Mar 4, 2023
1 parent 5e38f29 commit 30db4fa
Showing 4 changed files with 41 additions and 18 deletions.
8 changes: 4 additions & 4 deletions Jenkinsfile
@@ -3335,7 +3335,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
- model.activation=swiglu \
+ model.activation=fast-swiglu \
model.bias_activation_fusion=False \
model.hidden_dropout=0.0 \
model.attention_dropout=0.0 \
@@ -3371,7 +3371,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
- model.activation=swiglu \
+ model.activation=fast-swiglu \
model.bias_activation_fusion=False \
model.hidden_dropout=0.0 \
model.attention_dropout=0.0 \
@@ -3648,7 +3648,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.decoder.num_layers=2 \
model.decoder.hidden_size=64 \
model.decoder.num_attention_heads=8 \
- model.decoder.activation='swiglu' \
+ model.decoder.activation='fast-swiglu' \
model.decoder.masked_softmax_fusion=False \
model.decoder.bias_activation_fusion=False \
model.decoder.activations_checkpoint_method='block' \
@@ -3690,7 +3690,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.decoder.num_layers=2 \
model.decoder.hidden_size=64 \
model.decoder.num_attention_heads=8 \
- model.decoder.activation='swiglu' \
+ model.decoder.activation='fast-swiglu' \
model.decoder.masked_softmax_fusion=False \
model.decoder.bias_activation_fusion=False \
model.decoder.activations_checkpoint_method='block' \
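The Jenkinsfile changes above only swap the Hydra override model.activation=swiglu for model.activation=fast-swiglu; the rest of each CI command is untouched. As a rough sketch outside this commit, the same switch can be expressed with OmegaConf (the config library behind these overrides); the miniature config below is hypothetical and only mirrors the overrides visible in the test commands:

from omegaconf import OmegaConf

# Hypothetical, minimal stand-in for the model section used in the CI commands above.
cfg = OmegaConf.create(
    {
        "model": {
            "activation": "swiglu",           # old two-projection GLU path
            "bias_activation_fusion": False,  # the CI commands disable fusion explicitly
            "hidden_dropout": 0.0,
            "attention_dropout": 0.0,
        }
    }
)

# Equivalent of passing model.activation=fast-swiglu on the command line.
cfg.model.activation = "fast-swiglu"
print(OmegaConf.to_yaml(cfg))

The new value is a drop-in replacement: it selects the same SwiGLU math, only computed through a single fused projection (see the mlp.py diff below).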
@@ -71,7 +71,7 @@ model:
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
bias: True # Whether to use bias terms in all weight matrices.
- activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu']
+ activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
@@ -21,7 +21,7 @@ bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropo
bias: True # Whether to use bias terms in all weight matrices.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
arch: 'transformer' # Options: ['transformer', 'perceiver']
- activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
+ activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
hidden_steps: 32 # Number of latent vectors to use for pereceiver encoders
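All of the GLU-family options listed in these configs (geglu, reglu, swiglu and their fast- counterparts) follow the gated-linear-unit pattern from https://arxiv.org/pdf/2002.05202.pdf, which mlp.py cites below: an activation applied to one linear projection, multiplied elementwise by a second projection. A minimal sketch of that math, not taken from this commit (the function and tensor names are illustrative only):

import torch
import torch.nn.functional as F

def glu_variant(x, w_gate, w_value, act):
    # Generic GLU-family feed-forward gate: act(x @ w_gate) * (x @ w_value).
    # act = F.gelu -> geglu, act = F.relu -> reglu, act = F.silu -> swiglu
    # (SiLU is swish with beta = 1, as noted in the mlp.py comment).
    return act(x @ w_gate) * (x @ w_value)

x = torch.randn(2, 8)         # [tokens, hidden]
w_gate = torch.randn(8, 32)   # hidden -> ffn
w_value = torch.randn(8, 32)  # hidden -> ffn
print(glu_variant(x, w_gate, w_value, F.silu).shape)  # torch.Size([2, 32])

The fast- variants compute exactly the same function; the difference is in how the two projections are executed, which is what the mlp.py changes below implement.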
47 changes: 35 additions & 12 deletions nemo/collections/nlp/modules/common/megatron/mlp.py
@@ -83,18 +83,32 @@ def __init__(
self.dropout = dropout
self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_])

- if activation not in ['gelu', 'geglu', 'reglu', 'swiglu', 'squared-relu']:
+ supported_activations = [
+ 'gelu',
+ 'geglu',
+ 'reglu',
+ 'swiglu',
+ 'squared-relu',
+ 'fast-geglu',
+ 'fast-swiglu',
+ 'fast-reglu',
+ ]

+ if activation not in supported_activations:
raise ValueError(
- f"Activation {activation} not supported. Only gelu, geglu, reglu, swiglu, squared-relu are supported."
+ f"Activation {activation} not supported. Supported activations are {supported_activations}"
)

+ self.fast_glu_activation = activation in ['fast-geglu', 'fast-swiglu', 'fast-reglu']
no_async_tensor_model_parallel_allreduce = (
parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
)
# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
hidden_size,
- ffn_hidden_size, # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
+ ffn_hidden_size * 2
+ if self.fast_glu_activation
+ else ffn_hidden_size, # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
gather_output=False,
init_method=init_method,
skip_bias_add=True,
@@ -121,7 +135,14 @@ def __init__(
gradient_accumulation_fusion=gradient_accumulation_fusion,
)

- self.glu_activation_family = activation in ['geglu', 'reglu', 'swiglu']
+ self.glu_activation_family = activation in [
+ 'geglu',
+ 'reglu',
+ 'swiglu',
+ 'fast-geglu',
+ 'fast-reglu',
+ 'fast-swiglu',
+ ]
bias_activation_fusion_unavailable = activation in ['reglu', 'swiglu']

if bias_activation_fusion_unavailable and bias_activation_fusion:
@@ -144,13 +165,13 @@ def __init__(
# Give openai_gelu precedence over other activations if set, for HF compatibility. Normally this is off and shouldn't affect regular model training.
if openai_gelu:
self.activation_func = openai_gelu_func
- elif activation in ["gelu", "geglu"]:
+ elif activation in ["gelu", "geglu", "fast-geglu"]:
self.activation_func = F.gelu
elif onnx_safe:
self.activation_func = erf_gelu
- elif activation == "reglu":
+ elif activation in ["reglu", "fast-reglu"]:
self.activation_func = F.relu
- elif activation == "swiglu":
+ elif activation in ["swiglu", "fast-swiglu"]:
# SiLU or sigmoid linear unit is the same as swish with beta = 1 (which is what https://arxiv.org/pdf/2002.05202.pdf uses.)
self.activation_func = F.silu
elif activation == 'squared-relu':
@@ -191,20 +212,22 @@ def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

- if self.glu_activation_family:
+ if self.fast_glu_activation:
+ intermediate_parallel, intermediate_parallel_2 = torch.chunk(intermediate_parallel, 2, dim=-1)
+ if bias_parallel is not None:
+ bias_parallel, bias_parallel_2 = torch.chunk(bias_parallel, 2, dim=-1)
+ elif self.glu_activation_family and not self.fast_glu_activation:
intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

if self.bias_activation_fusion:
if self.activation == 'gelu':
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
- elif self.activation == 'geglu':
+ elif self.activation in ['geglu', 'fast-geglu']:
intermediate_parallel = fused_bias_geglu(
intermediate_parallel, bias_parallel, intermediate_parallel_2, bias_parallel_2
)

- elif self.activation in ['reglu', 'swiglu'] or (
- self.glu_activation_family and not self.bias_activation_fusion
- ):
+ elif self.glu_activation_family and not self.bias_activation_fusion:
if bias_parallel is not None:
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) * (
intermediate_parallel_2 + bias_parallel_2
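The core of the mlp.py change: for the fast- activations, dense_h_to_4h projects straight to 2 * ffn_hidden_size and the forward pass splits the result with torch.chunk, instead of running a second dense_h_to_4h_2 projection. A standalone sketch of the two paths, with plain nn.Linear standing in for the tensor-parallel layers (bias handling and the fusion kernels are omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, ffn = 64, 256
x = torch.randn(4, hidden)

# 'swiglu' path: two separate hidden -> ffn projections, gated elementwise.
h_to_4h = nn.Linear(hidden, ffn)
h_to_4h_2 = nn.Linear(hidden, ffn)
slow = F.silu(h_to_4h(x)) * h_to_4h_2(x)

# 'fast-swiglu' path: one hidden -> 2*ffn projection, then chunk into gate and value.
fused = nn.Linear(hidden, 2 * ffn)
gate, value = torch.chunk(fused(x), 2, dim=-1)
fast = F.silu(gate) * value

print(slow.shape, fast.shape)  # both torch.Size([4, 256])

With matching weights the two paths give identical outputs; the fused version simply replaces two GEMMs with one wider one, which is presumably what the 'fast' naming refers to.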
