2 changes: 1 addition & 1 deletion examples/pretrain_gpt3_175B.sh
@@ -23,7 +23,7 @@ options=" \
--num-attention-heads 96 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--micro-batch-size 1 \
--global-batch-size 1536 \
--rampup-batch-size 16 16 5859375 \
--train-samples 146484375 \
24 changes: 18 additions & 6 deletions megatron/arguments.py
@@ -21,6 +21,9 @@
import torch
import deepspeed

from megatron.enums import PositionEmbeddingType


def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -199,8 +202,7 @@ def parse_args(extra_args_provider=None, defaults={},
'and lr-warmup-samples'

# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
required_args = ['num_layers', 'hidden_size', 'num_attention_heads']
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)

@@ -219,10 +221,15 @@
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length

if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
if args.position_embedding_type == PositionEmbeddingType.absolute:
assert args.max_position_embeddings is not None
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
else:
assert args.max_position_embeddings is None

if args.lr is not None:
assert args.min_lr <= args.lr
if args.save is not None:
@@ -301,6 +308,11 @@ def _add_network_size_args(parser):
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x],
choices=list(PositionEmbeddingType),
default=PositionEmbeddingType.absolute,
help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.'
)

return parser

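For reference, a minimal standalone sketch of how the new flag and the revised checks interact. The enum is re-declared inline and the parser is a stub so the snippet runs on its own; it is not Megatron's actual parse_args.

```python
import argparse
import enum


class PositionEmbeddingType(enum.Enum):   # mirrors megatron/enums.py below
    rotary = 1
    absolute = 2


parser = argparse.ArgumentParser()
parser.add_argument('--position-embedding-type',
                    type=lambda x: PositionEmbeddingType[x],   # 'rotary' -> PositionEmbeddingType.rotary
                    choices=list(PositionEmbeddingType),
                    default=PositionEmbeddingType.absolute)
parser.add_argument('--max-position-embeddings', type=int, default=None)
parser.add_argument('--seq-length', type=int, default=None)

args = parser.parse_args(['--position-embedding-type', 'rotary', '--seq-length', '2048'])

# Revised validation: absolute embeddings still require --max-position-embeddings
# (large enough to cover the sequence length); rotary must leave it unset.
if args.position_embedding_type == PositionEmbeddingType.absolute:
    assert args.max_position_embeddings is not None
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
else:
    assert args.max_position_embeddings is None

print(args.position_embedding_type, args.max_position_embeddings)   # PositionEmbeddingType.rotary None
```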
1 change: 1 addition & 0 deletions megatron/checkpointing.py
@@ -61,6 +61,7 @@ def _compare(arg_name, old_arg_name=None):
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
_compare('position_embedding_type')
if args.vocab_file:
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
4 changes: 4 additions & 0 deletions megatron/model/enums.py → megatron/enums.py
@@ -26,3 +26,7 @@ class AttnType(enum.Enum):
class AttnMaskType(enum.Enum):
padding = 1
causal = 2

class PositionEmbeddingType(enum.Enum):
rotary = 1
absolute = 2
2 changes: 1 addition & 1 deletion megatron/model/bert_model.py
@@ -19,7 +19,7 @@

from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
2 changes: 1 addition & 1 deletion megatron/model/biencoder_model.py
@@ -8,7 +8,7 @@
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
2 changes: 1 addition & 1 deletion megatron/model/classification.py
@@ -19,7 +19,7 @@

from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
2 changes: 1 addition & 1 deletion megatron/model/fused_softmax.py
@@ -14,7 +14,7 @@
# limitations under the License.

import torch
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
Collaborator:

This move surprises me. Did you move enums up a folder and, if so, why?

Member (Author):

Basically, this caused a circular dependency once I added my enum to that file and imported it in arguments.py: from megatron.model.enums import AttnMaskType executes megatron/model/__init__.py, which imports a bunch of code, some of which imports arguments.py.
https://stackoverflow.com/questions/24302754/python-submodule-imports-using-init-py/24303380

To remove this dependency I moved enums outside model, as it is safe to say that importing enums should not be tied to model. Please see the Changes section in the PR description.
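To make the cycle concrete, here is a minimal, self-contained reproduction; the package name mypkg and the parse_args stub are stand-ins for the real Megatron layout, which has a longer but similarly shaped import chain.

```python
# Three throwaway modules written to a temp dir: arguments.py imports the enum
# from the model sub-package, whose __init__.py in turn imports arguments.py.
import os
import sys
import tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "mypkg", "model"))

files = {
    "mypkg/__init__.py": "",
    # stands in for megatron/arguments.py before the enum was moved
    "mypkg/arguments.py": "from mypkg.model.enums import PositionEmbeddingType\n"
                          "def parse_args():\n    return PositionEmbeddingType.absolute\n",
    # stands in for megatron/model/__init__.py pulling in code that needs arguments.py
    "mypkg/model/__init__.py": "from mypkg.arguments import parse_args\n",
    "mypkg/model/enums.py": "import enum\n"
                            "class PositionEmbeddingType(enum.Enum):\n"
                            "    rotary = 1\n    absolute = 2\n",
}
for rel, src in files.items():
    with open(os.path.join(root, rel), "w") as f:
        f.write(src)

sys.path.insert(0, root)
try:
    import mypkg.arguments
except ImportError as err:
    # "cannot import name 'parse_args' from partially initialized module 'mypkg.arguments' ..."
    print("circular import:", err)
```

Moving enums.py up to megatron/enums.py breaks the chain, because importing the enum no longer executes megatron/model/__init__.py.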



class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
4 changes: 1 addition & 3 deletions megatron/model/gpt_model.py
@@ -20,9 +20,9 @@

from megatron import get_args
from megatron import mpu
from megatron.enums import AttnMaskType
from .module import MegatronModule, fp32_to_float16

from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
@@ -182,7 +182,6 @@ def _to_float16(inputs):
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
@@ -224,7 +223,6 @@ def _logits_helper(embedding, lm_output):
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
62 changes: 35 additions & 27 deletions megatron/model/language_model.py
@@ -21,7 +21,7 @@
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.enums import LayerType, AttnMaskType, PositionEmbeddingType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -107,8 +107,6 @@ class Embedding(MegatronModule):
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
@@ -118,7 +116,6 @@ class Embedding(MegatronModule):
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0):
@@ -137,11 +134,17 @@ def __init__(self,
self._word_embeddings_key = 'word_embeddings'

# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
self.position_embedding_type = args.position_embedding_type
if self.position_embedding_type == PositionEmbeddingType.absolute:
max_position_embeddings = args.max_position_embeddings
assert max_position_embeddings is not None
self.position_embeddings = torch.nn.Embedding(
max_position_embeddings, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
else:
self.position_embeddings = None

# Token type embedding.
# Add this as an optional field that can be added through
@@ -179,8 +182,14 @@ def add_tokentype_embeddings(self, num_tokentypes):
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
embeddings = words_embeddings

if self.position_embedding_type == PositionEmbeddingType.absolute:
assert self.position_embeddings is not None
embeddings = embeddings + self.position_embeddings(position_ids)
else:
assert self.position_embeddings is None

if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
@@ -199,9 +208,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.position_embedding_type == PositionEmbeddingType.absolute:
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
@@ -225,16 +235,17 @@ def load_state_dict(self, state_dict, strict=True):
self.word_embeddings.load_state_dict(state_dict_, strict=strict)

# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
if self.position_embedding_type == PositionEmbeddingType.absolute:
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)

# Tokentype embedding.
if self.num_tokentypes > 0:
@@ -295,8 +306,6 @@ class TransformerLanguageModel(MegatronModule):
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
@@ -329,7 +338,6 @@ def __init__(self,
if self.pre_process:
self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
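To summarize the new behaviour of Embedding in isolation: with absolute embeddings a learned position table is added to the word embeddings, while with rotary nothing is added here and the rotation is applied later inside attention. The TinyEmbedding class and the sizes below are illustrative stand-ins, not Megatron code.

```python
import enum

import torch


class PositionEmbeddingType(enum.Enum):
    rotary = 1
    absolute = 2


class TinyEmbedding(torch.nn.Module):
    """Stripped-down stand-in for megatron.model.language_model.Embedding."""

    def __init__(self, hidden_size, vocab_size, position_embedding_type,
                 max_position_embeddings=None):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_embedding_type = position_embedding_type
        if position_embedding_type == PositionEmbeddingType.absolute:
            assert max_position_embeddings is not None
            self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size)
        else:
            # rotary: no learned table; positions enter via the q/k rotation in attention
            self.position_embeddings = None

    def forward(self, input_ids, position_ids):
        embeddings = self.word_embeddings(input_ids)
        if self.position_embedding_type == PositionEmbeddingType.absolute:
            embeddings = embeddings + self.position_embeddings(position_ids)
        return embeddings


tokens = torch.randint(0, 100, (2, 8))
positions = torch.arange(8).unsqueeze(0).expand(2, -1)
absolute = TinyEmbedding(16, 100, PositionEmbeddingType.absolute, max_position_embeddings=8)
rotary = TinyEmbedding(16, 100, PositionEmbeddingType.rotary)
print(absolute(tokens, positions).shape, rotary(tokens, positions).shape)  # both [2, 8, 16]
```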
2 changes: 1 addition & 1 deletion megatron/model/multiple_choice.py
@@ -19,7 +19,7 @@

from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
51 changes: 51 additions & 0 deletions megatron/model/positional_embeddings.py
@@ -0,0 +1,51 @@
# Extracted from: https://github.com/EleutherAI/gpt-neox
import torch


class RotaryEmbedding(torch.nn.Module):

def __init__(self, dim, base=10000, precision=torch.half):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision

def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float()
# [sx, 1 (b * np), hn]
self.cos_cached = emb.cos()[:, None, :]
self.sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
self.cos_cached = self.cos_cached.bfloat16()
self.sin_cached = self.sin_cached.bfloat16()
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


# rotary pos emb helpers:

def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
Collaborator:

I like having helper functions, but is there a reason why we can't run them from the forward?

Member (Author):

I'm not familiar with torch.jit.script, but I guess it compiles the function? I admit the code was copy-pasted from EleutherAI with some small modifications. I'm okay with removing this helper if you want.

cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16
cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
2 changes: 1 addition & 1 deletion megatron/model/realm_model.py
@@ -6,7 +6,7 @@
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
2 changes: 1 addition & 1 deletion megatron/model/t5_model.py
@@ -21,7 +21,7 @@
get_args,
mpu
)
from megatron.model.enums import AttnMaskType
from megatron.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import (