added RPE + fixed RMSNorm (#6304)
Signed-off-by: David Mosallanezhad <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>
Authored by 2 people and committed via web-flow on Mar 28, 2023
Parent: 10ceb1c · Commit: 63902e5
Showing 3 changed files with 86 additions and 7 deletions.
@@ -896,11 +896,21 @@ def on_test_start(self) -> None:
 
     @property
     def encoder(self):
-        return EncEmb(self.enc_dec_model.encoder_embedding, self.enc_dec_model.enc_dec_model.encoder, self.device)
+        return EncEmb(
+            self.enc_dec_model.encoder_embedding,
+            self.enc_dec_model.enc_dec_model.encoder,
+            self.enc_dec_model.encoder_relative_position_embedding,
+            self.device,
+        )
 
     @property
     def decoder(self):
-        return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)
+        return DecEmb(
+            self.enc_dec_model.decoder_embedding,
+            self.enc_dec_model.enc_dec_model.decoder,
+            self.enc_dec_model.decoder_relative_position_embedding,
+            self.device,
+        )
 
     @property
     def log_softmax(self):
39 changes: 34 additions & 5 deletions nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -86,12 +86,13 @@ class DecEmb(torch.nn.Module, Exportable):
     Combines decoder_embedding with the decoder component
     """
 
-    def __init__(self, decoder_embedding, decoder, device):
+    def __init__(self, decoder_embedding, decoder, rpe, device):
         super(DecEmb, self).__init__()
 
         self.decoder_embedding = decoder_embedding
         self.decoder = decoder
         self.device = device
+        self.rpe = rpe
 
         # properties needed for export
         self.training = False
@@ -105,8 +106,19 @@ def modules(self):
     def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, decoder_mems):
         position_ids = build_position_ids(input_ids)
         dec_input = self.decoder_embedding(input_ids, position_ids, token_type_ids=None)
+
+        rpe = None
+        if self.rpe is not None:
+            rpe = self.rpe(query_seq_length=input_ids.size(1), key_seq_length=input_ids.size(1),)
+
         dec_out = (
-            self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
+            self.decoder(
+                dec_input,
+                decoder_mask,
+                encoder_embeddings.permute(1, 0, 2),
+                encoder_mask,
+                dec_self_attention_relative_position_bias=rpe,
+            )
             .permute(1, 0, 2)
             .float()
         )
@@ -166,12 +178,13 @@ class EncEmb(torch.nn.Module, Exportable):
     Combines encoder_embedding with the encoder component
     """
 
-    def __init__(self, encoder_embedding, encoder, device):
+    def __init__(self, encoder_embedding, encoder, rpe, device):
         super(EncEmb, self).__init__()
 
         self.encoder_embedding = encoder_embedding
         self.encoder = encoder
         self.device = device
+        self.rpe = rpe
 
         # properties needed for export
         self.training = False
@@ -183,11 +196,27 @@ def modules(self):
         return (self.encoder_embedding, self.encoder)
 
     def forward(self, input_ids, encoder_mask):
-        position_ids = build_position_ids(input_ids)
+        if self.rpe is None:
+            position_ids = build_position_ids(input_ids)
+        else:
+            position_ids = None
+
         enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)
 
         # pass input through the encoder
-        return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2).float()
+        enc_seq_length = input_ids.size(1)
+
+        rpe = None
+        if self.rpe is not None:
+            rpe = self.rpe(query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,)
+
+        return (
+            self.encoder(
+                enc_input=enc_input, enc_attn_mask=encoder_mask, enc_self_attention_relative_position_bias=rpe
+            )
+            .permute(1, 0, 2)
+            .float()
+        )
 
     def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
         seq_len = random.randint(0, 128)
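Both new forward passes call self.rpe(query_seq_length=..., key_seq_length=...) and hand the result to the transformer as a *_self_attention_relative_position_bias. As a rough, self-contained sketch of the interface such a module exposes (the bucketing scheme, head count, and output shape below are assumptions for illustration, not taken from this diff):

    import torch

    class ToyRelativePositionBias(torch.nn.Module):
        # Illustrative stand-in for a relative position embedding module: it maps clipped
        # relative distances to a learned per-head bias that attention layers can add to
        # their scores. The (1, num_heads, q_len, k_len) shape convention is an assumption.
        def __init__(self, num_heads=8, max_distance=16):
            super().__init__()
            self.max_distance = max_distance
            self.bias = torch.nn.Embedding(2 * max_distance + 1, num_heads)

        def forward(self, query_seq_length, key_seq_length):
            q_pos = torch.arange(query_seq_length)[:, None]
            k_pos = torch.arange(key_seq_length)[None, :]
            rel = (k_pos - q_pos).clamp(-self.max_distance, self.max_distance) + self.max_distance
            return self.bias(rel).permute(2, 0, 1).unsqueeze(0)  # (q, k, heads) -> (1, heads, q, k)

    rpe = ToyRelativePositionBias()
    print(rpe(query_seq_length=6, key_seq_length=6).shape)  # torch.Size([1, 8, 6, 6])

Passing rpe=None preserves the old behaviour: both wrappers skip the bias entirely when no module is supplied.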
40 changes: 40 additions & 0 deletions nemo/utils/export_utils.py
@@ -46,6 +46,25 @@ class ExportFormat(Enum):
 }
 
 
+class TorchRMSNorm(nn.Module):
+    def __init__(self, weight, eps=1e-6):
+        """
+        LayerNorm without bias
+        """
+        super().__init__()
+        self.weight = weight
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # can be only calculated with precision=32
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
+
+
 class LinearWithBiasSkip(nn.Module):
     def __init__(self, weight, bias, skip_bias_add):
         super(LinearWithBiasSkip, self).__init__()
@@ -230,6 +249,7 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         else:
             return None
 
+        n_state = n.state_dict()
         mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
@@ -251,6 +271,25 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
 
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
+
         return mod
 
+    def replace_MixedFusedRMSNorm(n: nn.Module):
+        """
+        Replaces Apex's MixedFusedRMSNorm with equivalent Pytorch layer. This is required for ONNX export.
+        Args:
+            n: the MixedFusedRMSNorm pytorch module to replace
+        Returns:
+            Equivalent module
+        """
+
+        p = next(n.parameters())
+
+        if isinstance(n, MixedFusedRMSNorm):
+            mod = TorchRMSNorm(n.state_dict()['weight'], n.eps).to(p.device)
+        else:
+            return None
+
+        return mod
+
     def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
@@ -296,6 +335,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
         "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
+        "MixedFusedRMSNorm": replace_MixedFusedRMSNorm,
     }
 
 except Exception as e:

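The TorchRMSNorm class added above computes the squared-mean statistic in float32 before rescaling, which is what the "can be only calculated with precision=32" comment refers to. A small, self-contained illustration of why that matters for half-precision activations (the values are synthetic, chosen only to trigger fp16 overflow):

    import torch

    x = torch.full((4,), 300.0, dtype=torch.float16)

    # Squaring in fp16 overflows (300**2 = 90000 > 65504, the fp16 maximum), so the RMS
    # statistic becomes inf and rsqrt(inf + eps) would collapse the normalized output to zero.
    print(x.pow(2).mean(-1, keepdim=True))                    # tensor([inf], dtype=torch.float16)

    # Casting to float32 first, as TorchRMSNorm.forward does, keeps the statistic finite.
    print(x.to(torch.float32).pow(2).mean(-1, keepdim=True))  # tensor([90000.])

replace_MixedFusedRMSNorm then swaps Apex's MixedFusedRMSNorm for this plain-PyTorch equivalent, reusing the original weight tensor and eps, so the layer can be traced for ONNX export.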