From 63902e5cefe33e3788f7a790389a3d898e1bf8a3 Mon Sep 17 00:00:00 2001
From: David
Date: Tue, 28 Mar 2023 16:23:23 -0700
Subject: [PATCH] added RPE + fixed RMSNorm (#6304)

Signed-off-by: David Mosallanezhad
Co-authored-by: David Mosallanezhad
---
 .../machine_translation/megatron_nmt_model.py | 14 ++++++-
 .../common/megatron/megatron_export.py        | 39 +++++++++++++++---
 nemo/utils/export_utils.py                    | 40 +++++++++++++++++++
 3 files changed, 86 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
index 964a018cb272..99791467cee5 100644
--- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
+++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
@@ -896,11 +896,21 @@ def on_test_start(self) -> None:
 
     @property
     def encoder(self):
-        return EncEmb(self.enc_dec_model.encoder_embedding, self.enc_dec_model.enc_dec_model.encoder, self.device)
+        return EncEmb(
+            self.enc_dec_model.encoder_embedding,
+            self.enc_dec_model.enc_dec_model.encoder,
+            self.enc_dec_model.encoder_relative_position_embedding,
+            self.device,
+        )
 
     @property
     def decoder(self):
-        return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)
+        return DecEmb(
+            self.enc_dec_model.decoder_embedding,
+            self.enc_dec_model.enc_dec_model.decoder,
+            self.enc_dec_model.decoder_relative_position_embedding,
+            self.device,
+        )
 
     @property
     def log_softmax(self):
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_export.py b/nemo/collections/nlp/modules/common/megatron/megatron_export.py
index f196fd7987fa..a7c4cf355ea4 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_export.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -86,12 +86,13 @@ class DecEmb(torch.nn.Module, Exportable):
     Combines decoder_embedding with the decoder component
     """
 
-    def __init__(self, decoder_embedding, decoder, device):
+    def __init__(self, decoder_embedding, decoder, rpe, device):
         super(DecEmb, self).__init__()
 
         self.decoder_embedding = decoder_embedding
         self.decoder = decoder
         self.device = device
+        self.rpe = rpe
 
         # properties needed for export
         self.training = False
@@ -105,8 +106,19 @@ def modules(self):
     def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, decoder_mems):
         position_ids = build_position_ids(input_ids)
         dec_input = self.decoder_embedding(input_ids, position_ids, token_type_ids=None)
+
+        rpe = None
+        if self.rpe is not None:
+            rpe = self.rpe(query_seq_length=input_ids.size(1), key_seq_length=input_ids.size(1),)
+
         dec_out = (
-            self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
+            self.decoder(
+                dec_input,
+                decoder_mask,
+                encoder_embeddings.permute(1, 0, 2),
+                encoder_mask,
+                dec_self_attention_relative_position_bias=rpe,
+            )
             .permute(1, 0, 2)
             .float()
         )
@@ -166,12 +178,13 @@ class EncEmb(torch.nn.Module, Exportable):
     Combines encoder_embedding with the encoder component
     """
 
-    def __init__(self, encoder_embedding, encoder, device):
+    def __init__(self, encoder_embedding, encoder, rpe, device):
         super(EncEmb, self).__init__()
 
         self.encoder_embedding = encoder_embedding
         self.encoder = encoder
         self.device = device
+        self.rpe = rpe
 
         # properties needed for export
         self.training = False
@@ -183,11 +196,27 @@ def modules(self):
         return (self.encoder_embedding, self.encoder)
 
     def forward(self, input_ids, encoder_mask):
-        position_ids = build_position_ids(input_ids)
+        if self.rpe is None:
+            position_ids = build_position_ids(input_ids)
+        else:
+            position_ids = None
+
         enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)
 
         # pass input through the encoder
-        return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2).float()
+        enc_seq_length = input_ids.size(1)
+
+        rpe = None
+        if self.rpe is not None:
+            rpe = self.rpe(query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,)
+
+        return (
+            self.encoder(
+                enc_input=enc_input, enc_attn_mask=encoder_mask, enc_self_attention_relative_position_bias=rpe
+            )
+            .permute(1, 0, 2)
+            .float()
+        )
 
     def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
         seq_len = random.randint(0, 128)
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index 3fa7b1322aad..1946741e4c6d 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -46,6 +46,25 @@ class ExportFormat(Enum):
 }
 
 
+class TorchRMSNorm(nn.Module):
+    def __init__(self, weight, eps=1e-6):
+        """
+        LayerNorm without bias
+        """
+        super().__init__()
+        self.weight = weight
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # can be only calculated with precision=32
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
+
+
 class LinearWithBiasSkip(nn.Module):
     def __init__(self, weight, bias, skip_bias_add):
         super(LinearWithBiasSkip, self).__init__()
@@ -230,6 +249,7 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         else:
             return None
 
+        n_state = n.state_dict()
         mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
@@ -251,6 +271,25 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
 
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
+
+        return mod
+
+    def replace_MixedFusedRMSNorm(n: nn.Module):
+        """
+        Replaces Apex's MixedFusedRMSNorm with equivalent Pytorch layer. This is required for ONNX export.
+        Args:
+            n: the MixedFusedRMSNorm pytorch module to replace
+        Returns:
+            Equivalent module
+        """
+
+        p = next(n.parameters())
+
+        if isinstance(n, MixedFusedRMSNorm):
+            mod = TorchRMSNorm(n.state_dict()['weight'], n.eps).to(p.device)
+        else:
+            return None
+
         return mod
 
     def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
@@ -296,6 +335,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
         "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
+        "MixedFusedRMSNorm": replace_MixedFusedRMSNorm,
     }
 
 except Exception as e:
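
Note on the RPE interface used by the export wrappers above: the patch only assumes that the relative position embedding handed to EncEmb/DecEmb is a callable taking query_seq_length and key_seq_length keywords and returning an attention bias, which is then passed through as enc_self_attention_relative_position_bias / dec_self_attention_relative_position_bias. The sketch below is a minimal standalone module with that call signature; the class name, the clamping-based bucketing, and the (1, heads, q_len, k_len) bias layout are illustrative assumptions, not NeMo's actual implementation.

import torch
import torch.nn as nn


class SketchRelativePositionEmbedding(nn.Module):
    """Minimal relative-position-bias module matching the rpe(...) call above."""

    def __init__(self, num_heads, num_buckets=32):
        super().__init__()
        self.num_buckets = num_buckets
        self.bias = nn.Embedding(num_buckets, num_heads)

    def forward(self, query_seq_length, key_seq_length):
        # relative distance between every (query, key) position pair
        context_position = torch.arange(query_seq_length)[:, None]
        memory_position = torch.arange(key_seq_length)[None, :]
        relative_position = memory_position - context_position
        # clamp distances into a fixed number of buckets (simplified, no log-spacing)
        buckets = relative_position.clamp(
            -(self.num_buckets // 2), self.num_buckets // 2 - 1
        ) + self.num_buckets // 2
        # (q_len, k_len, heads) -> (1, heads, q_len, k_len), a layout attention bias typically expects
        values = self.bias(buckets)
        return values.permute(2, 0, 1).unsqueeze(0)


rpe = SketchRelativePositionEmbedding(num_heads=8)
bias = rpe(query_seq_length=6, key_seq_length=6)
print(bias.shape)  # torch.Size([1, 8, 6, 6])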
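
Quick standalone check for the TorchRMSNorm replacement: it should reproduce x / sqrt(mean(x^2) + eps) * weight with the variance accumulated in fp32, no mean subtraction, and no bias. The snippet below copies the class from the hunk above and compares it against that reference on random data; the test harness itself is illustrative and not part of the patch.

import torch
import torch.nn as nn


class TorchRMSNorm(nn.Module):
    # copied from the export_utils.py hunk above
    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # variance is accumulated in fp32 for numerical stability
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden = 16
    x = torch.randn(2, 4, hidden)
    weight = nn.Parameter(torch.ones(hidden))
    norm = TorchRMSNorm(weight)

    # reference RMSNorm: scale by the root-mean-square, then apply the learned weight
    ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
    assert torch.allclose(norm(x), ref, atol=1e-6)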