Fast glu activations #6058
Merged (10 commits) on Feb 23, 2023
Jenkinsfile (8 changes: 4 additions & 4 deletions)
@@ -3249,7 +3249,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
-model.activation=swiglu \
+model.activation=fast-swiglu \
model.bias_activation_fusion=False \
model.hidden_dropout=0.0 \
model.attention_dropout=0.0 \
@@ -3285,7 +3285,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
-model.activation=swiglu \
+model.activation=fast-swiglu \
model.bias_activation_fusion=False \
model.hidden_dropout=0.0 \
model.attention_dropout=0.0 \
@@ -3562,7 +3562,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.decoder.num_layers=2 \
model.decoder.hidden_size=64 \
model.decoder.num_attention_heads=8 \
-model.decoder.activation='swiglu' \
+model.decoder.activation='fast-swiglu' \
model.decoder.masked_softmax_fusion=False \
model.decoder.bias_activation_fusion=False \
model.decoder.activations_checkpoint_method='block' \
@@ -3604,7 +3604,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.decoder.num_layers=2 \
model.decoder.hidden_size=64 \
model.decoder.num_attention_heads=8 \
-model.decoder.activation='swiglu' \
+model.decoder.activation='fast-swiglu' \
model.decoder.masked_softmax_fusion=False \
model.decoder.bias_activation_fusion=False \
model.decoder.activations_checkpoint_method='block' \
@@ -71,7 +71,7 @@ model:
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
bias: True # Whether to use bias terms in all weight matrices.
-activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu']
+activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
@@ -21,7 +21,7 @@ bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropo
bias: True # Whether to use bias terms in all weight matrices.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
arch: 'transformer' # Options: ['transformer', 'perceiver']
-activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
+activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
hidden_steps: 32 # Number of latent vectors to use for perceiver encoders
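Both configs expose the new fast variants through the same activation field. One way to switch it programmatically, sketched with OmegaConf (the config path below is hypothetical; NeMo loads these YAML files through Hydra/OmegaConf):

from omegaconf import OmegaConf

# Hypothetical config path; any NeMo megatron config with an activation field works the same way.
cfg = OmegaConf.load('conf/megatron_gpt_config.yaml')
cfg.model.activation = 'fast-swiglu'  # same effect as the model.activation=fast-swiglu override above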
nemo/collections/nlp/modules/common/megatron/mlp.py (47 changes: 35 additions & 12 deletions)
@@ -83,18 +83,32 @@ def __init__(
         self.dropout = dropout
         self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_])

-        if activation not in ['gelu', 'geglu', 'reglu', 'swiglu', 'squared-relu']:
+        supported_activations = [
+            'gelu',
+            'geglu',
+            'reglu',
+            'swiglu',
+            'squared-relu',
+            'fast-geglu',
+            'fast-swiglu',
+            'fast-reglu',
+        ]
+
+        if activation not in supported_activations:
             raise ValueError(
-                f"Activation {activation} not supported. Only gelu, geglu, reglu, swiglu, squared-relu are supported."
+                f"Activation {activation} not supported. Supported activations are {supported_activations}"
             )

+        self.fast_glu_activation = activation in ['fast-geglu', 'fast-swiglu', 'fast-reglu']
         no_async_tensor_model_parallel_allreduce = (
             parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
         )
         # Project to 4h.
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             hidden_size,
-            ffn_hidden_size,  # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
+            ffn_hidden_size * 2
+            if self.fast_glu_activation
+            else ffn_hidden_size,  # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True,
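The widened projection above is the core of the fast GLU variants: both GLU branches come out of a single GEMM and are split afterwards, instead of being produced by two separate projections (dense_h_to_4h plus dense_h_to_4h_2). A minimal sketch of the parameter bookkeeping, using plain torch.nn.Linear with made-up sizes in place of ColumnParallelLinear:

import torch.nn as nn

hidden_size, ffn_hidden_size = 64, 256

# Fast GLU path: one projection of width 2 * ffn_hidden_size, chunked later.
fused_proj = nn.Linear(hidden_size, ffn_hidden_size * 2)

# Non-fast GLU path: two separate projections of width ffn_hidden_size.
proj_1 = nn.Linear(hidden_size, ffn_hidden_size)
proj_2 = nn.Linear(hidden_size, ffn_hidden_size)

# Same number of parameters either way; only the number of GEMM launches differs.
assert sum(p.numel() for p in fused_proj.parameters()) == sum(
    p.numel() for p in proj_1.parameters()
) + sum(p.numel() for p in proj_2.parameters())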
@@ -121,7 +135,14 @@ __init__(
             gradient_accumulation_fusion=gradient_accumulation_fusion,
         )

-        self.glu_activation_family = activation in ['geglu', 'reglu', 'swiglu']
+        self.glu_activation_family = activation in [
+            'geglu',
+            'reglu',
+            'swiglu',
+            'fast-geglu',
+            'fast-reglu',
+            'fast-swiglu',
+        ]
         bias_activation_fusion_unavailable = activation in ['reglu', 'swiglu']

         if bias_activation_fusion_unavailable and bias_activation_fusion:
@@ -144,13 +165,13 @@ __init__(
         # Give openai_gelu precedence over other activations if set, for HF compatibility. Normally this is off and shouldn't affect regular model training.
         if openai_gelu:
             self.activation_func = openai_gelu_func
-        elif activation in ["gelu", "geglu"]:
+        elif activation in ["gelu", "geglu", "fast-geglu"]:
             self.activation_func = F.gelu
         elif onnx_safe:
             self.activation_func = erf_gelu
-        elif activation == "reglu":
+        elif activation in ["reglu", "fast-reglu"]:
             self.activation_func = F.relu
-        elif activation == "swiglu":
+        elif activation in ["swiglu", "fast-swiglu"]:
             # SiLU or sigmoid linear unit is the same as swish with beta = 1 (which is what https://arxiv.org/pdf/2002.05202.pdf uses.)
             self.activation_func = F.silu
         elif activation == 'squared-relu':
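Note that each fast variant reuses the same pointwise function as its non-fast counterpart; the fast prefix only changes how the two GLU branches are produced, not the nonlinearity. A condensed, hypothetical summary of the chain above (the squared-relu entry is an assumption about the branch not shown here):

import torch
import torch.nn.functional as F

# Hypothetical lookup equivalent to the elif chain above.
ACTIVATION_FN = {
    'gelu': F.gelu,
    'geglu': F.gelu,
    'fast-geglu': F.gelu,
    'reglu': F.relu,
    'fast-reglu': F.relu,
    'swiglu': F.silu,  # SiLU == swish with beta = 1
    'fast-swiglu': F.silu,
    'squared-relu': lambda x: torch.square(F.relu(x)),  # assumed definition
}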
@@ -191,20 +212,22 @@ def forward(self, hidden_states):
         # [s, b, 4hp]
         intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

-        if self.glu_activation_family:
+        if self.fast_glu_activation:
+            intermediate_parallel, intermediate_parallel_2 = torch.chunk(intermediate_parallel, 2, dim=-1)
+            if bias_parallel is not None:
+                bias_parallel, bias_parallel_2 = torch.chunk(bias_parallel, 2, dim=-1)
+        elif self.glu_activation_family and not self.fast_glu_activation:
             intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

         if self.bias_activation_fusion:
             if self.activation == 'gelu':
                 intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
-            elif self.activation == 'geglu':
+            elif self.activation in ['geglu', 'fast-geglu']:
                 intermediate_parallel = fused_bias_geglu(
                     intermediate_parallel, bias_parallel, intermediate_parallel_2, bias_parallel_2
                 )

-        elif self.activation in ['reglu', 'swiglu'] or (
-            self.glu_activation_family and not self.bias_activation_fusion
-        ):
+        elif self.glu_activation_family and not self.bias_activation_fusion:
             if bias_parallel is not None:
                 intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) * (
                     intermediate_parallel_2 + bias_parallel_2
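To make the new control flow in forward() concrete, here is a small self-contained sketch with hypothetical shapes, using plain torch ops in place of the tensor-parallel layers; it checks that the fast path (one widened GEMM, then torch.chunk) produces the same swiglu output as the two-projection path while launching one fewer GEMM:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, ffn_hidden_size, tokens = 64, 256, 8
x = torch.randn(tokens, hidden_size)

# Weights of the widened projection used by the fast path (arbitrary values here).
weight = torch.randn(2 * ffn_hidden_size, hidden_size)
bias = torch.randn(2 * ffn_hidden_size)

# Fast path: one GEMM, then chunk both the activations and the bias.
out = F.linear(x, weight)                            # bias kept separate, as with skip_bias_add=True
inter, inter_2 = torch.chunk(out, 2, dim=-1)
bias_1, bias_2 = torch.chunk(bias, 2, dim=-1)
fast = F.silu(inter + bias_1) * (inter_2 + bias_2)   # fast-swiglu

# Non-fast path: the second branch comes from a separate projection
# (what dense_h_to_4h_2 provides), here built by splitting the same weight.
w_1, w_2 = torch.chunk(weight, 2, dim=0)
slow = F.silu(F.linear(x, w_1) + bias_1) * (F.linear(x, w_2) + bias_2)

torch.testing.assert_close(fast, slow)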