Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ def _add_network_size_args(parser):
default=PositionEmbeddingType.absolute,
help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.'
)
group.add_argument('--glu-activation', type=str,
choices=["liglu", "geglu", "reglu", "swiglu"],
Comment thread
jaketae marked this conversation as resolved.
Outdated
default="",
Comment thread
jaketae marked this conversation as resolved.
Outdated
help='GLU activations to use.'
)

return parser

Expand Down
File renamed without changes.
11 changes: 10 additions & 1 deletion megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import deepspeed

from .glu_activations import geglu, liglu, reglu, swiglu
from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb

# flags required to enable jit fusion kernels
Expand Down Expand Up @@ -76,7 +77,15 @@ def __init__(self, init_method, output_layer_init_method):

self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
if args.glu_activation:
glu_lookup = {
Comment thread
jaketae marked this conversation as resolved.
Outdated
"gegelu": geglu,
"liglu": liglu,
"reglu": reglu,
"swiglu": swiglu,
}
self.activation_func = glu_lookup[args.glu_activation]
elif args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
Expand Down