reverted RoPE implementation to lucidrains
VarunGumma committed Jun 22, 2024
1 parent d96946e commit c80a5ef
Showing 10 changed files with 304 additions and 491 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -18,15 +18,14 @@ modeling and other text generation tasks.


# Usage
This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `LoRA`, `RoPE`, `YaRN`, and `ALiBi` for the `Transformer` model and the `translation` task. You can add the following flags to `fairseq-train`/`fairseq-interactive`/`fairseq-generate` to use them:
This clone of fairseq supports `Knowledge Distillation`, `Recurrent Stacking`, `LoRA`, `RoPE`, and `ALiBi` for the `Transformer` model and the `translation` task. You can add the following flags to `fairseq-train`/`fairseq-interactive`/`fairseq-generate` to use them:

| **Name and Citation** | **Description** | **Flags to Activate** | **Source** |
|-----------------------|-----------------------|-----------------------|------------|
| **Knowledge Distillation** ([Hinton _et al_.](https://arxiv.org/abs/1503.02531), [Kim & Rush](https://aclanthology.org/D16-1139), [Wang _et al_.](https://aclanthology.org/2021.acl-long.504), [Gumma _et al_.](https://aclanthology.org/2023.eamt-1.11/)) | Transfers _soft_ information from a pretrained teacher model to a smaller student model | `--teacher-checkpoint-path $teacher_ckpt --task translation_with_kd --criterion label_smoothed_cross_entropy_with_kd --kd-args '{"strategy": "word_level"}'` | [Selective Distillation](https://github.com/LeslieOverfitting/selective_distillation) |
| **Recurrent Stacking** ([Dabre & Fujita](https://ojs.aaai.org/index.php/AAAI/article/view/4590)) | Extreme parameter sharing technique in which all layers in the encoder/decoder are shared | `--encoder-recurrent-stacking $encoder_recurrent_stacking --decoder-recurrent-stacking $decoder_recurrent_stacking` | - |
| **Low-Rank Adaptation (LoRA)** ([Hu _et al_.](https://openreview.net/forum?id=nZeVKeeFYf9)) | Efficient model adaptation technique that modifies a small number of model parameters while freezing the rest | `--lora-args '{"r": 8, "alpha": 16, "dropout": 0.05, "bias": "none", "target_modules": "k_proj,v_proj", "rank_scaled": false}' --use-native-attention --load-checkpoint-liberally` | [LoRA Implementation](https://github.com/microsoft/LoRA) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency in self-attention formulation | `--rope-args '{"max_position_embeddings": 2048, "base": 10000, "type": "vanilla"}' --use-native-attention --no-token-positional-embeddings` | [RoPE Implementation](https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_yarn.py) |
| **Yet another RoPE extensioN method (YaRN)** ([Peng _et al_.](https://openreview.net/forum?id=wHBfxhZu1u)) | Compute-efficient method to extend the context window of models | `--yarn-args '{"max_position_embeddings": 2048, "base": 10000, "type": "vanilla", "original_max_position_embeddings": 256, "extrapolation_factor": 1, "attn_factor": 1, "beta_fast": 32, "beta_slow": 1}' --use-native-attention --no-token-positional-embeddings` | [YaRN Implementation](https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_yarn.py) |
| **Rotary Positional Embedding (RoPE)** ([Su _et al_.](https://arxiv.org/abs/2104.09864)) | Encodes absolute position with a rotation matrix and incorporates explicit relative position dependency into the self-attention formulation (see the sketch after this table) | `--rope-args '{"theta": 10000, "use_xpos": false, "learned_freq": false}' --use-native-attention --no-token-positional-embeddings` | [RoPE Implementation](https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py) |
| **Gated FC** | Add a gating module to the Fully Connected layers in a Transformer | `--encoder-use-gated-fc --decoder-use-gated-fc` | [Gated FC Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L160) |
| **RMSNorm** ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)) | Use RMSNorm instead of LayerNorm in a Transformer | `--encoder-use-rmsnorm --decoder-use-rmsnorm` | [RMSNorm Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L63) |
| **Attention with Linear Biases (ALiBi)** ([Press _et al_.](https://openreview.net/forum?id=R8sQPpGCv0)) | Simple and efficient position method that biases query-key attention scores with a penalty proportional to their distance | `--alibi-args '{"alibi_asymmetrical": "false"}' --no-token-positional-embeddings --load-checkpoint-liberally` | [ALiBi Implementation](https://github.com/EIFY/fairseq) |
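
For context, the sketch below shows how the `--rope-args` JSON above is consumed: the flags are parsed and forwarded to lucidrains' `RotaryEmbedding` constructor, mirroring the new `NativeMultiheadAttention` code further down in this diff. This assumes the `rotary-embedding-torch` package is installed; the `head_dim` value is a hypothetical example.

```python
import json

from rotary_embedding_torch import RotaryEmbedding  # lucidrains' implementation

# JSON string as it would be passed via `--rope-args`
rope_args = json.loads('{"theta": 10000, "use_xpos": false, "learned_freq": false}')

head_dim = 64  # hypothetical per-head dimension (embed_dim // num_heads)

# Mirrors the constructor call in NativeMultiheadAttention below;
# with dim = head_dim // 2, only the first half of each head's
# dimensions are rotated and the rest pass through unchanged.
rotary_pos_embed = RotaryEmbedding(
    dim=head_dim // 2,
    freqs_for="lang",
    theta=rope_args.get("theta", 10000),
    learned_freq=rope_args.get("learned_freq", False),
    use_xpos=rope_args.get("use_xpos", False),
)
```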
23 changes: 10 additions & 13 deletions fairseq/models/transformer/transformer_legacy.py
@@ -228,8 +228,6 @@ def base_architecture(args):
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
@@ -238,6 +236,9 @@ def base_architecture(args):
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)

if args.offload_activations:
args.checkpoint_activations = True


@register_model_architecture("transformer", "transformer_iwslt_de_en")
def transformer_iwslt_de_en(args):
@@ -296,26 +297,22 @@ def transformer_wmt_en_de_big_t2t(args):
######################################################### CUSTOM ARCHITECTURES #########################################################


@register_model_architecture("transformer", "transformer_base18L")
def _transformer_base18L(args):
@register_model_architecture("transformer", "transformer_IT2_dist")
def transformer_IT2_dist(args):
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.encoder_layers = getattr(args, "encoder_layers", 18)
args.decoder_layers = getattr(args, "decoder_layers", 18)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
base_architecture(args)


@register_model_architecture("transformer", "transformer_IT2_dist")
def transformer_base18L(args):
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", True)
_transformer_base18L(args)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", True
)
base_architecture(args)


@register_model_architecture("transformer", "transformer_IT2")
def transformer_deep(args):
def transformer_IT2(args):
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.encoder_layers = getattr(args, "encoder_layers", 18)
args.decoder_layers = getattr(args, "decoder_layers", 18)
16 changes: 2 additions & 14 deletions fairseq/modules/__init__.py
@@ -48,14 +48,7 @@
RelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention,
)
from .rotary_positional_embedding import (
RotaryPositionalEmbedding,
LinearScalingRotaryPositionalEmbedding,
DynamicNTKScalingRotaryPositionalEmbedding,
YaRNScaledRotaryPositionalEmbedding,
YaRNScaledRotaryPositionalEmbedding,
DynamicYaRNScaledRotaryPositionalEmbedding,
)
from .rotary_positional_embedding import RotaryEmbedding
from .positional_encoding import (
RelPositionalEncoding,
)
@@ -114,12 +107,7 @@
"PositionalEmbedding",
"RelPositionMultiHeadedAttention",
"RelPositionalEncoding",
"RotaryPositionalEmbedding",
"RotaryEmbedding",
"RotaryPositionMultiHeadedAttention",
"RMSNorm",
"LinearScalingRotaryPositionalEmbedding",
"DynamicNTKScalingRotaryPositionalEmbedding",
"YaRNScaledRotaryPositionalEmbedding",
"DynamicYaRNScaledRotaryPositionalEmbedding",
"DynamicYaRNScaledRotaryPositionalEmbedding",
]
16 changes: 5 additions & 11 deletions fairseq/modules/espnet_multihead_attention.py
@@ -11,10 +11,7 @@
import torch
from torch import nn

from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding,
apply_rotary_pos_emb,
)
from fairseq.modules.rotary_positional_embedding import RotaryEmbedding


class ESPNETMultiHeadedAttention(nn.Module):
@@ -214,9 +211,7 @@ def __init__(
if precision == "fp16":
precision = torch.half

self.rotary_emb = RotaryPositionalEmbedding(
self.rotary_ndims, base=rotary_emd_base, precision=precision
)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims, base=rotary_emd_base)

def forward(self, query, key, value, key_padding_mask=None, **kwargs):
"""Compute rotary position attention.
@@ -235,10 +230,9 @@ def forward(self, query, key, value, key_padding_mask=None, **kwargs):
query = query.view(T, B, self.h, self.d_k)
key = key.view(T, B, self.h, self.d_k)
value = value.view(T, B, self.h, self.d_k)
cos, sin = self.rotary_emb(value, seq_len=T)
query, key = apply_rotary_pos_emb(
query, key, cos, sin, offset=0
) # offset is based on layer_past

query = self.rotary_emb.rotate_queries_or_keys(query)
key = self.rotary_emb.rotate_queries_or_keys(key)

query = query.view(T, B, self.h * self.d_k)
key = key.view(T, B, self.h * self.d_k)
128 changes: 25 additions & 103 deletions fairseq/modules/native_multihead_attention.py
@@ -11,20 +11,10 @@
from torch.nn import Parameter

from fairseq import utils

# from rotary_embedding_torch import RotaryEmbedding
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from fairseq.modules.multihead_attention import MultiheadAttention

from fairseq.modules.rotary_positional_embedding import (
apply_rotary_pos_emb,
RotaryPositionalEmbedding,
LinearScalingRotaryPositionalEmbedding,
DynamicNTKScalingRotaryPositionalEmbedding,
YaRNScaledRotaryPositionalEmbedding,
DynamicYaRNScaledRotaryPositionalEmbedding,
)
from fairseq.modules.rotary_positional_embedding import RotaryEmbedding


class NativeMultiheadAttention(MultiheadAttention):
@@ -49,7 +39,6 @@ def __init__(
q_noise=0.0,
qn_block_size=8,
rope_args=None,
yarn_args=None,
):
super().__init__(embed_dim, num_heads, dictionary=dictionary)
self.embed_dim = embed_dim
@@ -72,91 +61,26 @@ def __init__(
self.encoder_decoder_attention = encoder_decoder_attention

self.rope_args = json.loads(rope_args) if rope_args is not None else None
self.yarn_args = json.loads(yarn_args) if yarn_args is not None else None
self.yarn_pos_embed = None
self.rotary_pos_embed = None

# both self.rope_args and self.yarn_args cannot be set at the same time
assert not (
self.rope_args is not None and self.yarn_args is not None
), "Both rotary and yarn position embeddings cannot be set at the same time"

if self.rope_args is not None:
if self.rope_args["type"] == "vanilla":
self.rotary_pos_embed = RotaryPositionalEmbedding(
dim=self.head_dim,
base=self.rope_args.get("base", 10000),
max_position_embeddings=self.rope_args.get(
"max_position_embeddings", 2048
),
)
elif self.rope_args["type"] == "linear":
self.rotary_pos_embed = LinearScalingRotaryPositionalEmbedding(
dim=self.head_dim,
base=self.rope_args.get("base", 10000),
scaling_factor=self.rope_args.get("scaling_factor", 1.0),
max_position_embeddings=self.rope_args.get(
"max_position_embeddings", 2048
),
)
elif self.rope_args["type"] == "dynamic":
self.rotary_pos_embed = DynamicNTKScalingRotaryPositionalEmbedding(
dim=self.head_dim,
base=self.rope_args.get("base", 10000),
max_position_embeddings=self.rope_args.get(
"max_position_embeddings", 2048
),
)
else:
raise ValueError(
f"Unknown rotary position embedding type: {self.rope_args['type']}. Allowed types are: vanilla, linear, dynamic"
)

if self.yarn_args is not None:
if self.yarn_args["type"] == "vanilla":
self.yarn_pos_embed = YaRNScaledRotaryPositionalEmbedding(
dim=self.head_dim,
base=self.yarn_args.get("base", 10000),
scale=self.yarn_args.get("scale", 1.0),
max_position_embeddings=self.yarn_args.get(
"max_position_embeddings", 2048
),
original_max_position_embeddings=self.yarn_args.get(
"original_max_position_embeddings", 256
),
extrapolation_factor=self.yarn_args.get(
"extrapolation_factor", 1.0
),
attn_factor=self.yarn_args.get("attn_factor", 1),
beta_fast=self.yarn_args.get("beta_fast", 32),
beta_slow=self.yarn_args.get("beta_slow", 1),
)
elif self.yarn_args["type"] == "dynamic":
self.yarn_pos_embed = DynamicYaRNScaledRotaryPositionalEmbedding(
dim=self.head_dim,
base=self.yarn_args.get("base", 10000),
max_position_embeddings=self.yarn_args.get(
"max_position_embeddings", 2048
),
original_max_position_embeddings=self.yarn_args.get(
"original_max_position_embeddings", 256
),
extrapolation_factor=self.yarn_args.get(
"extrapolation_factor", 1.0
),
attn_factor=self.yarn_args.get("attn_factor", 1),
beta_fast=self.yarn_args.get("beta_fast", 32),
beta_slow=self.yarn_args.get("beta_slow", 1),
finetuned=self.yarn_args.get("finetuned", False),
)
else:
raise ValueError(
f"Unknown rotary position embedding type: {self.yarn_args['type']}. Allowed types are: vanilla, dynamic"
)
self.rotary_pos_embed = RotaryEmbedding(
dim=(self.head_dim // 2),
freqs_for="lang",
cache_if_possible=True,
theta=self.rope_args.get("theta", 10000),
learned_freq=self.rope_args.get("learned_freq", False),
use_xpos=self.rope_args.get("use_xpos", False),
xpos_scale_base=self.rope_args.get("xpos_scale_base", 512),
interpolate_factor=self.rope_args.get("interpolate_factor", 1.0),
theta_rescale_factor=self.rope_args.get("theta_rescale_factor", 1.0),
)
self.use_xpos = self.rope_args.get("use_xpos", False)
else:
self.rotary_pos_embed = None

assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
assert (
not self.self_attention or self.qkv_same_dim
), "Self-attention requires query, key and value to be of the same size"

self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
@@ -341,22 +265,20 @@ def forward(
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)

assert k is not None
assert k.size(1) == src_len

if self.rotary_pos_embed is not None or self.yarn_pos_embed is not None:
if self.rotary_pos_embed is not None:
# q shape: [bsz * num_heads, tgt_len, head_dim]
q_ = q.view(kv_bsz, self.num_heads, -1, self.head_dim)
k_ = k.view(kv_bsz, self.num_heads, -1, self.head_dim)

# this is mutually exclusive
cos, sin = (
self.rotary_pos_embed(q_, seq_len=q_.shape[2])
if self.rotary_pos_embed is not None
else self.yarn_pos_embed(q_, seq_len=q_.shape[2])
)

q_, k_ = apply_rotary_pos_emb(q_, k_, cos, sin)
if not self.use_xpos:
q_ = self.rotary_pos_embed.rotate_queries_or_keys(q_)
k_ = self.rotary_pos_embed.rotate_queries_or_keys(k_)
else:
q_, k_ = self.rotary_pos_embed.rotate_queries_and_keys(q_, k_)

# reshape back to [bsz * num_heads, tgt_len, head_dim]
q = q_.view(kv_bsz * self.num_heads, -1, self.head_dim)
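
As a usage note, the sketch below illustrates the two rotation paths the new `forward` distinguishes: `rotate_queries_or_keys` for plain RoPE and `rotate_queries_and_keys` when xPos is enabled, applied to tensors shaped like the `(kv_bsz, num_heads, seq_len, head_dim)` view above. All sizes are illustrative assumptions, and the snippet again presumes `rotary-embedding-torch` is installed.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64  # hypothetical sizes
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)

# Plain RoPE: queries and keys are rotated independently.
rope = RotaryEmbedding(dim=head_dim // 2, use_xpos=False)
q_rot = rope.rotate_queries_or_keys(q)
k_rot = rope.rotate_queries_or_keys(k)

# xPos: queries and keys must be rotated together so their
# length-extrapolation scale factors stay consistent.
xpos = RotaryEmbedding(dim=head_dim // 2, use_xpos=True, xpos_scale_base=512)
q_x, k_x = xpos.rotate_queries_and_keys(q, k)
```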
3 changes: 2 additions & 1 deletion fairseq/modules/rms_norm.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn as nn


class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
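
The class body is truncated in this view; for reference, a minimal sketch of the standard RMSNorm computation ([Zhang and Sennrich](https://papers.nips.cc/paper_files/paper/2019/hash/1e8a19426224ca89e83cef47f1e7f53b-Abstract.html)), in the style of the Hugging Face implementation linked in the README. This is an assumption about what `rms_norm.py` contains, not a copy of it.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Root-mean-square layer norm: rescale by 1/RMS(x), with no mean subtraction."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Mean of squares over the feature dimension, then rescale and apply the learned gain.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
```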
(Diffs for the remaining 4 changed files are not shown.)
