
Commit

added RMSNorm
VarunGumma committed Jun 16, 2024
1 parent d2440ac commit 79d32aa
Showing 7 changed files with 51 additions and 15 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -27,7 +27,8 @@ This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --use-native-attention --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"max_position_embeddings": 2048, "base": 10000, "type": "vanilla"}' --use-native-attention --no-token-positional-embeddings` | [RoPE Implementation](https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_yarn.py) |
| **Yet another RoPE extensioN method (YaRN)** ([Peng _et al_.](https://openreview.net/forum?id=wHBfxhZu1u)) | Compute-efficient method to extend the context window of models | `--yarn-args '{"max_position_embeddings": 2048, "base": 10000, "type": "vanilla", "original_max_position_embeddings": 256, "extrapolation_factor": 1, "attn_factor": 1, "beta_fast": 32, "beta_slow": 1}' --use-native-attention --no-token-positional-embeddings` | [YaRN Implementation](https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_yarn.py) |
| **Gated FC** | Add a gating module to the Fully Connected layers in a Transformer | `--encoder-use-gated-fc --decoder-use-gated-fc` | [Gated FC Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py) |
| **Gated FC** | Add a gating module to the Fully Connected layers in a Transformer | `--encoder-use-gated-fc --decoder-use-gated-fc` | [Gated FC Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | Use RMSNorm instead of LayerNorm in a Transformer | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L63) |
| **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--alibi-args '{"alibi_asymmetrical": "false"}' --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
| **Factorized Embedding Parameterization** ([Lan _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Parameterizes large embeddings by adding an intermediate bottleneck layer | `--encoder-factorized-embed-dim $encoder_fac_embed_dim --decoder-factorized-embed-dim $decoder_fac_embed_dim --factorized-embed-activation-fn $fac_embed_activation_fn` | - |
| **Penultimate Linear Transformation Activation** | Adds activation to the penultimate linear transformation before the final projection onto the vocabulary | `--decoder-output-activation-fn $decoder_out_activation_fn` | - |
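For reference (not part of the commit page itself), the RMSNorm in the new row, as implemented in fairseq/modules/rms_norm.py further down, rescales each hidden vector by its root mean square rather than centering and scaling it the way LayerNorm does:

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}} \odot g$$

with $g$ a learned per-dimension weight initialized to ones, $d$ the hidden size, and $\epsilon = 10^{-6}$ by default, matching the new module.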
fairseq/models/transformer/transformer_config.py (3 changes: 3 additions & 0 deletions)
@@ -57,6 +57,9 @@ class EncDecBaseConfig(FairseqDataclass):
use_gated_fc: bool = field(
default=False, metadata={"help": "use gated fc layers in the encoder/decoder"}
)
use_rmsnorm: bool = field(
default=False, metadata={"help": "use RMSNorm instead of LayerNorm"}
)
learned_pos: bool = field(
default=False, metadata={"help": "use learned positional embeddings"}
)
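As a minimal sketch (not from this commit), the new field can be flipped programmatically on the encoder and decoder sub-configs; on the command line it surfaces as the `--encoder-use-rmsnorm` / `--decoder-use-rmsnorm` flags listed in the README. The sketch assumes the fork's default TransformerConfig values are usable as-is.

```python
from fairseq.models.transformer import TransformerConfig

cfg = TransformerConfig()
cfg.encoder.use_rmsnorm = True   # equivalent to --encoder-use-rmsnorm
cfg.decoder.use_rmsnorm = True   # equivalent to --decoder-use-rmsnorm
```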
fairseq/models/transformer/transformer_decoder.py (9 changes: 6 additions & 3 deletions)
@@ -6,7 +6,6 @@
import math
from typing import Any, Dict, List, Optional

import json
import torch
import torch.nn as nn
from torch import Tensor
@@ -21,6 +20,7 @@
FairseqDropout,
LayerDropModuleList,
LayerNorm,
RMSNorm,
PositionalEmbedding,
transformer_layer,
)
@@ -105,7 +105,7 @@ def __init__(
else None
)
self.layernorm_embedding = (
LayerNorm(self.embed_dim, export=cfg.export)
self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)
if cfg.layernorm_embedding
else None
)
@@ -135,7 +135,7 @@ def __init__(
self.num_layers = len(self.layers)

if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
self.layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.layer_norm = self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)
else:
self.layer_norm = None

@@ -164,6 +164,9 @@ def __init__(
else:
self.alibi = None

def normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

def build_output_projection(self, cfg, dictionary, embed_tokens):
if cfg.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
fairseq/models/transformer/transformer_encoder.py (11 changes: 8 additions & 3 deletions)
@@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, List, Optional

@@ -19,6 +18,7 @@
FairseqDropout,
LayerDropModuleList,
LayerNorm,
RMSNorm,
PositionalEmbedding,
transformer_layer,
)
@@ -76,7 +76,9 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
)

self.layernorm_embedding = (
LayerNorm(embed_dim, export=cfg.export) if cfg.layernorm_embedding else None
self.normalization(embed_dim, rms=cfg.encoder.use_rmsnorm)
if cfg.layernorm_embedding
else None
)

if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
@@ -108,7 +110,7 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
self.num_layers = len(self.layers)

self.layer_norm = (
LayerNorm(embed_dim, export=cfg.export)
self.normalization(embed_dim, rms=cfg.encoder.use_rmsnorm)
if cfg.encoder.normalize_before
else None
)
@@ -123,6 +125,9 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
else:
self.alibi = None

def normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

def build_encoder_layer(self, cfg):
layer = transformer_layer.TransformerEncoderLayerBase(
cfg, return_fc=self.return_fc
fairseq/modules/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -60,6 +60,7 @@
RelPositionalEncoding,
)

from .rms_norm import RMSNorm

__all__ = [
"AdaptiveInput",
@@ -115,6 +116,7 @@
"RelPositionalEncoding",
"RotaryPositionalEmbedding",
"RotaryPositionMultiHeadedAttention",
"RMSNorm",
"LinearScalingRotaryPositionalEmbedding",
"DynamicNTKScalingRotaryPositionalEmbedding",
"YaRNScaledRotaryPositionalEmbedding",
fairseq/modules/rms_norm.py (15 changes: 15 additions & 0 deletions)
@@ -0,0 +1,15 @@
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
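A quick illustrative sanity check of the module above (not part of the commit): with the default all-ones weight, every hidden vector of the output should come out with a root mean square of roughly 1.

```python
import torch

from fairseq.modules import RMSNorm  # re-exported via fairseq/modules/__init__.py above

norm = RMSNorm(hidden_size=512)      # weight starts as ones, eps defaults to 1e-6
x = torch.randn(2, 7, 512)           # (batch, time, hidden)
y = norm(x)

rms = y.pow(2).mean(dim=-1).sqrt()   # per-vector root mean square
print(rms.min().item(), rms.max().item())  # both close to 1.0
```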
fairseq/modules/transformer_layer.py (23 changes: 15 additions & 8 deletions)
@@ -13,7 +13,7 @@
from fairseq.models.transformer import TransformerConfig
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from fairseq.modules import LayerNorm, MultiheadAttention, NativeMultiheadAttention
from fairseq.modules import LayerNorm, MultiheadAttention, NativeMultiheadAttention, RMSNorm


class TransformerEncoderLayerBase(nn.Module):
@@ -40,7 +40,7 @@ def __init__(self, cfg, return_fc=False):
self.quant_noise_block_size = cfg.quant_noise.pq_block_size
self.use_native_attention = cfg.use_native_attention
self.self_attn = self.build_self_attention(self.embed_dim, cfg)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.self_attn_layer_norm = self.normalization(self.embed_dim, rms=cfg.encoder.use_rmsnorm)

self.dropout_module = FairseqDropout(
cfg.dropout, module_name=self.__class__.__name__
@@ -81,7 +81,7 @@ def __init__(self, cfg, return_fc=False):
else:
self.gate_fc = None

self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.final_layer_norm = self.normalization(self.embed_dim, rms=cfg.encoder.use_rmsnorm)

def build_fc1(self, input_dim, output_dim, bias, q_noise, qn_block_size):
return quant_noise(
@@ -173,6 +173,9 @@ def build_self_attention(self, embed_dim, cfg):
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.encoder.xformers_att_config,
)

def normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

def residual_connection(self, x, residual):
return residual + x
@@ -298,6 +301,7 @@ def __init__(
self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.cfg = cfg
self.embed_dim = cfg.decoder.embed_dim
self.dropout_module = FairseqDropout(
cfg.dropout, module_name=self.__class__.__name__
@@ -315,7 +319,7 @@ def __init__(
add_zero_attn=add_zero_attn,
)
self.attn_ln = (
LayerNorm(self.embed_dim)
self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)
if utils.safe_getattr(cfg, "scale_attn", False)
else None
)
@@ -337,17 +341,17 @@ def __init__(
)
self.normalize_before = cfg.decoder.normalize_before

self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.self_attn_layer_norm = self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)

if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.encoder_attn_layer_norm = self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)

self.ffn_layernorm = (
LayerNorm(cfg.decoder.ffn_embed_dim)
self.normalization(cfg.decoder.ffn_embed_dim, rms=cfg.decoder.use_rmsnorm)
if utils.safe_getattr(cfg, "scale_fc", False)
else None
)
@@ -389,7 +393,7 @@ def __init__(
else:
self.gate_fc = None

self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.final_layer_norm = self.normalization(self.embed_dim, rms=cfg.decoder.use_rmsnorm)

self.need_attn = True
self.onnx_trace = False
@@ -463,6 +467,9 @@ def prepare_for_onnx_export_(self):
def residual_connection(self, x, residual):
return residual + x

def normalization(self, dim, rms=False):
return LayerNorm(dim, export=self.cfg.export) if not rms else RMSNorm(dim)

def forward(
self,
x,
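The same `normalization` helper is added to the encoder, the decoder, and both layer classes above, so every `LayerNorm` call site now goes through one dispatch point. A standalone re-creation of that dispatch (illustrative only; the real code lives as a method and reads `export` from `self.cfg`, here it is a free function with `export` passed explicitly) behaves as follows:

```python
import torch.nn as nn

from fairseq.modules import LayerNorm, RMSNorm

def normalization(dim, rms=False, export=False):
    # RMSNorm only when the corresponding --*-use-rmsnorm flag is set;
    # otherwise the usual (possibly fused) LayerNorm, exactly as before.
    return RMSNorm(dim) if rms else LayerNorm(dim, export=export)

default_norm = normalization(512)        # LayerNorm: behaviour unchanged by this patch
rms_norm = normalization(512, rms=True)  # RMSNorm: selected by the new flag

assert isinstance(default_norm, nn.Module) and not isinstance(default_norm, RMSNorm)
assert isinstance(rms_norm, RMSNorm)
```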
