diff --git a/examples/pretrain_gpt3_175B.sh b/examples/pretrain_gpt3_175B.sh index ad0d244d7..88761b555 100755 --- a/examples/pretrain_gpt3_175B.sh +++ b/examples/pretrain_gpt3_175B.sh @@ -23,7 +23,7 @@ options=" \ --num-attention-heads 96 \ --seq-length 2048 \ --max-position-embeddings 2048 \ - --micro-batch-size 1 \ + --micro-batch-size 1 \ --global-batch-size 1536 \ --rampup-batch-size 16 16 5859375 \ --train-samples 146484375 \ diff --git a/megatron/arguments.py b/megatron/arguments.py index 0c2811117..761641b05 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -21,6 +21,9 @@ import torch import deepspeed +from megatron.enums import PositionEmbeddingType + + def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False): """Parse all arguments.""" @@ -199,8 +202,7 @@ def parse_args(extra_args_provider=None, defaults={}, 'and lr-warmup-samples' # Check required arguments. - required_args = ['num_layers', 'hidden_size', 'num_attention_heads', - 'max_position_embeddings'] + required_args = ['num_layers', 'hidden_size', 'num_attention_heads'] for req_arg in required_args: _check_arg_is_not_none(args, req_arg) @@ -219,10 +221,15 @@ def parse_args(extra_args_provider=None, defaults={}, assert args.encoder_seq_length is not None args.seq_length = args.encoder_seq_length - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length + if args.position_embedding_type == PositionEmbeddingType.absolute: + assert args.max_position_embeddings is not None + if args.seq_length is not None: + assert args.max_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_position_embeddings >= args.decoder_seq_length + else: + assert args.max_position_embeddings is None + if args.lr is not None: assert args.min_lr <= args.lr if args.save is not None: @@ -301,6 +308,11 @@ def _add_network_size_args(parser): group.add_argument('--bert-no-binary-head', action='store_false', help='Disable BERT binary head.', dest='bert_binary_head') + group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], + choices=list(PositionEmbeddingType), + default=PositionEmbeddingType.absolute, + help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' 
+ ) return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 3cc6a8e2e..829fb1101 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -61,6 +61,7 @@ def _compare(arg_name, old_arg_name=None): _compare('hidden_size') _compare('num_attention_heads') _compare('max_position_embeddings') + _compare('position_embedding_type') if args.vocab_file: _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') diff --git a/megatron/model/enums.py b/megatron/enums.py similarity index 91% rename from megatron/model/enums.py rename to megatron/enums.py index b6992fefa..84d9a4ff0 100644 --- a/megatron/model/enums.py +++ b/megatron/enums.py @@ -26,3 +26,7 @@ class AttnType(enum.Enum): class AttnMaskType(enum.Enum): padding = 1 causal = 2 + +class PositionEmbeddingType(enum.Enum): + rotary = 1 + absolute = 2 diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py index 3ff5039d5..4cb650b36 100644 --- a/megatron/model/bert_model.py +++ b/megatron/model/bert_model.py @@ -19,7 +19,7 @@ from megatron import get_args from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import get_language_model from megatron.model import LayerNorm diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 51ac0a060..0f0a6698f 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -8,7 +8,7 @@ from megatron.checkpointing import get_checkpoint_name from megatron import mpu, get_tokenizer from megatron.model.bert_model import bert_position_ids -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.language_model import get_language_model from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal diff --git a/megatron/model/classification.py b/megatron/model/classification.py index d975072f7..94dc5fe7d 100644 --- a/megatron/model/classification.py +++ b/megatron/model/classification.py @@ -19,7 +19,7 @@ from megatron import get_args, print_rank_last from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model from megatron.model.utils import get_linear_layer diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 097b29ef4..aa1384207 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -14,7 +14,7 @@ # limitations under the License. 
import torch -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 293526465..8a1d189e0 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -20,9 +20,9 @@ from megatron import get_args from megatron import mpu +from megatron.enums import AttnMaskType from .module import MegatronModule, fp32_to_float16 -from .enums import AttnMaskType from .language_model import parallel_lm_logits from .language_model import get_language_model from .utils import init_method_normal @@ -182,7 +182,6 @@ def _to_float16(inputs): EmbeddingPipe, args.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, @@ -224,7 +223,6 @@ def _logits_helper(embedding, lm_output): EmbeddingPipe, args.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index dd5f6972b..1022164ec 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -21,7 +21,7 @@ from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import LayerType, AttnMaskType +from megatron.enums import LayerType, AttnMaskType, PositionEmbeddingType from megatron.model.transformer import ParallelTransformer from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal, scaled_init_method_normal @@ -107,8 +107,6 @@ class Embedding(MegatronModule): Arguments: hidden_size: hidden size vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding embedding_dropout_prob: dropout probability for embeddings init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value @@ -118,7 +116,6 @@ class Embedding(MegatronModule): def __init__(self, hidden_size, vocab_size, - max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0): @@ -137,11 +134,17 @@ def __init__(self, self._word_embeddings_key = 'word_embeddings' # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' - # Initialize the position embeddings. - self.init_method(self.position_embeddings.weight) + self.position_embedding_type = args.position_embedding_type + if self.position_embedding_type == PositionEmbeddingType.absolute: + max_position_embeddings = args.max_position_embeddings + assert max_position_embeddings is not None + self.position_embeddings = torch.nn.Embedding( + max_position_embeddings, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + else: + self.position_embeddings = None # Token type embedding. # Add this as an optional field that can be added through @@ -179,8 +182,14 @@ def add_tokentype_embeddings(self, num_tokentypes): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. 
words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings + embeddings = words_embeddings + + if self.position_embedding_type == PositionEmbeddingType.absolute: + assert self.position_embeddings is not None + embeddings = embeddings + self.position_embeddings(position_ids) + else: + assert self.position_embeddings is None + if tokentype_ids is not None: assert self.tokentype_embeddings is not None embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) @@ -199,9 +208,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', state_dict_ = {} state_dict_[self._word_embeddings_key] \ = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + if self.position_embedding_type == PositionEmbeddingType.absolute: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) if self.num_tokentypes > 0: state_dict_[self._tokentype_embeddings_key] \ = self.tokentype_embeddings.state_dict( @@ -225,16 +235,17 @@ def load_state_dict(self, state_dict, strict=True): self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. - if self._position_embeddings_key in state_dict: - state_dict_ = state_dict[self._position_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] - self.position_embeddings.load_state_dict(state_dict_, strict=strict) + if self.position_embedding_type == PositionEmbeddingType.absolute: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. if self.num_tokentypes > 0: @@ -295,8 +306,6 @@ class TransformerLanguageModel(MegatronModule): Arguments: transformer_hparams: transformer hyperparameters vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding embedding_dropout_prob: dropout probability for embeddings num_tokentypes: size of the token-type embeddings. 
0 value will ignore this embedding @@ -329,7 +338,6 @@ def __init__(self, if self.pre_process: self.embedding = Embedding(self.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, self.init_method, self.num_tokentypes) diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py index c43bd969c..a8445e0cb 100644 --- a/megatron/model/multiple_choice.py +++ b/megatron/model/multiple_choice.py @@ -19,7 +19,7 @@ from megatron import get_args, print_rank_last from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model from megatron.model.utils import get_linear_layer diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py new file mode 100644 index 000000000..3494f9e4e --- /dev/null +++ b/megatron/model/positional_embeddings.py @@ -0,0 +1,51 @@ +# Extracted from: https://github.com/EleutherAI/gpt-neox +import torch + + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000, precision=torch.half): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + # [sx, 1 (b * np), hn] + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + self.cos_cached = self.cos_cached.bfloat16() + self.sin_cached = self.sin_cached.bfloat16() + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + +# rotary pos emb helpers: + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] 
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) \ No newline at end of file diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py index 5730a85e3..c57f51592 100644 --- a/megatron/model/realm_model.py +++ b/megatron/model/realm_model.py @@ -6,7 +6,7 @@ from megatron.model import BertModel from .module import MegatronModule from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal from megatron.model.language_model import get_language_model diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py index beb4f0ee5..7b59b9e18 100644 --- a/megatron/model/t5_model.py +++ b/megatron/model/t5_model.py @@ -21,7 +21,7 @@ get_args, mpu ) -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.language_model import parallel_lm_logits, get_language_model from megatron.model.transformer import LayerNorm from megatron.model.utils import ( diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 47433d725..c8055015c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -17,11 +17,12 @@ import math import torch import torch.nn.functional as F +from torch import nn from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import AttnMaskType, LayerType, AttnType +from megatron.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType from megatron.model import LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -29,6 +30,8 @@ import deepspeed +from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb + # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) @@ -119,6 +122,7 @@ def __init__(self, init_method, args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 + self.position_embedding_type = args.position_embedding_type self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 @@ -192,6 +196,9 @@ def __init__(self, init_method, get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint + if self.position_embedding_type == PositionEmbeddingType.rotary: + self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype) + def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): # hidden_states: [sq, b, h] @@ -274,6 +281,18 @@ def forward(self, hidden_states, attention_mask, layer_past=None, dtype=query_layer.dtype, device=torch.cuda.current_device()) + # Rotary embeddings + if self.position_embedding_type == PositionEmbeddingType.rotary: + apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb + + seq_len = key_layer.shape[0] + offset = 0 + if layer_past is not None and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset) + # Raw attention scores. 
[b * np, sq, sk] matmul_result = torch.baddbmm( matmul_result, diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index a1c1a4c71..180ed2c51 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -42,10 +42,14 @@ def init_checkpointed_activations_memory_buffer(): - """Initializ the memory buffer for the checkpointed activations.""" + """Initialize the memory buffer for the checkpointed activations.""" args = get_args() - per_layer = args.micro_batch_size * args.max_position_embeddings * \ + upper_bound_sequence_length = max( + args.seq_length if args.seq_length is not None else 0, + args.decoder_seq_length if args.decoder_seq_length is not None else 0 + ) + per_layer = args.micro_batch_size * upper_bound_sequence_length * \ args.hidden_size // args.tensor_model_parallel_size assert args.num_layers % args.checkpoint_num_layers == 0, \ 'number of layers is not divisible by checkpoint-num-layers'
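
The core of the patch is the rotary path added in megatron/model/positional_embeddings.py and wired into ParallelSelfAttention. Below is a minimal, self-contained sketch of the same math; it re-implements the frequency table and the (q * cos) + (rotate_half(q) * sin) formula inline rather than importing megatron (so it runs without apex/deepspeed installed), and the shapes sq=8, b*np=4, hn=16 are illustrative only.

import torch

def rotate_half(x):
    # Same helper as the patch: swap the two halves of the last dim, negating the second.
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def rotary_cos_sin(seq_len, dim, base=10000):
    # Mirrors RotaryEmbedding.forward: one angle per channel pair, duplicated so the
    # cos/sin tables cover the full head dimension and broadcast over b * np.
    inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)               # [seq_len, dim]
    return emb.cos()[:, None, :], emb.sin()[:, None, :]   # each [seq_len, 1, dim]

sq, bnp, hn = 8, 4, 16        # [sq, b * np, hn], the layout ParallelSelfAttention uses
q = torch.randn(sq, bnp, hn)
k = torch.randn(sq, bnp, hn)

cos, sin = rotary_cos_sin(sq, hn)
q_rot = (q * cos) + (rotate_half(q) * sin)    # same formula as apply_rotary_pos_emb
k_rot = (k * cos) + (rotate_half(k) * sin)

# Each channel pair is rotated by a position-dependent angle, so norms are preserved
# and only the relative angle between query and key positions enters the attention scores.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)

In the patch itself, the jit-scripted apply_rotary_pos_emb is used for fp16 while apply_rotary_pos_emb_torch is used for bf16 (scripting fails with bf16), and the offset argument lets incremental decoding with layer_past rotate only the newly appended positions.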
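
The new --position-embedding-type flag is parsed straight into the PositionEmbeddingType enum via the type=lambda construction. The throwaway parser below (a stand-in for illustration, not megatron's real argument setup) shows the behaviour; the enum mirrors megatron/enums.py.

import argparse
import enum

class PositionEmbeddingType(enum.Enum):   # same members as megatron/enums.py
    rotary = 1
    absolute = 2

parser = argparse.ArgumentParser()
parser.add_argument('--position-embedding-type',
                    type=lambda x: PositionEmbeddingType[x],
                    choices=list(PositionEmbeddingType),
                    default=PositionEmbeddingType.absolute)

# The string is converted by indexing the enum by name; omitting the flag keeps the default.
args = parser.parse_args(['--position-embedding-type', 'rotary'])
assert args.position_embedding_type is PositionEmbeddingType.rotary
assert parser.parse_args([]).position_embedding_type is PositionEmbeddingType.absolute

Downstream, parse_args now requires --max-position-embeddings (and checks it against --seq-length and --decoder-seq-length) only for the absolute type, and asserts that it is absent for rotary, which is why the argument was dropped from the required_args list.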
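
Finally, since max_position_embeddings may now be None, mpu/random.py sizes the checkpointed-activation buffer from the actual sequence lengths instead. A quick back-of-the-envelope check of the revised formula with hypothetical GPT-3-175B-style numbers (hidden 12288, seq 2048, tensor-parallel 8, micro batch 1; not all of these values come from the patch):

micro_batch_size, hidden_size, tensor_model_parallel_size = 1, 12288, 8
seq_length, decoder_seq_length = 2048, None   # decoder length is only set for encoder-decoder models

upper_bound_sequence_length = max(
    seq_length if seq_length is not None else 0,
    decoder_seq_length if decoder_seq_length is not None else 0,
)
per_layer = micro_batch_size * upper_bound_sequence_length * hidden_size \
    // tensor_model_parallel_size
assert per_layer == 3_145_728   # elements per transformer layer in the buffer

The result matches the old max_position_embeddings-based sizing whenever the two values coincide; it only changes behaviour when they differ or when max_position_embeddings is unset, as it is for rotary embeddings.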