fope #4043

71 changes: 69 additions & 2 deletions lmdeploy/pytorch/backends/default/rotary_embedding.py
@@ -3,10 +3,11 @@
import math

import torch
import torch.nn.functional as F
from torch import nn

from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder,
RotaryEmbeddingImpl, YarnParameters)
from ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters)


def _rotary_embedding_fwd(position_ids: torch.Tensor,
@@ -270,6 +271,64 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
device_type=device)


class FopeRotaryEmbeddingImpl(RotaryEmbeddingImpl):

def __init__(self,
dim: int,
max_position_embeddings: int = 4096,
scaling_factor: float = 1.0,
params: FopeParameters = None):
super().__init__(dim, scaling_factor=scaling_factor)
self.head_dim = dim
self.max_position_embeddings = max_position_embeddings
self.attention_scaling = scaling_factor
self.params = params

inv_freq = self.params.inv_freq
inv_freq_idx_selected = inv_freq > 2 * torch.pi / self.max_position_embeddings
if self.params.num_inv_freq is not None and inv_freq_idx_selected.sum() > (inv_freq.shape[-1] -
self.params.num_inv_freq):
inv_freq_idx_selected[-self.params.num_inv_freq:] = False
self.inv_freq = inv_freq[inv_freq_idx_selected]
self.register_buffer('inv_freq', self.inv_freq, persistent=False)

def forward(self, x: torch.Tensor, position_ids: torch.Tensor, sin_coef: torch.Tensor, cos_coef: torch.Tensor):
"""forward."""
if self.inv_freq.device != x.device:
self.inv_freq = self.inv_freq.to(x.device)

inv_freq = self.inv_freq
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)

batch_size, seq_len, _ = x.shape
if self.params.fope_sep_head:
pos_cos = freqs.cos().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1)
pos_sin = freqs.sin().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1)
else:
pos_cos = freqs.cos()
pos_sin = freqs.sin()

if self.params.fope_sep_head:
sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef.float())
cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef.float())
else:
sin = torch.einsum('btD, Dd -> btd', pos_sin, sin_coef.float())
cos = torch.einsum('btD, Dd -> btd', pos_cos, cos_coef.float())

sin = F.pad(input=sin, pad=(0, self.head_dim // 2 - sin.size(-1)), mode='constant', value=1)
cos = F.pad(input=cos, pad=(0, self.head_dim // 2 - cos.size(-1)), mode='constant', value=1)

sin = torch.cat((sin, sin), dim=-1)
cos = torch.cat((cos, cos), dim=-1)

cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
"""Rotary embedding builder."""

@@ -282,6 +341,7 @@ def build(
yarn_params: YarnParameters = None,
longrope_params: LongRoPEScalingParameters = None,
llama3_params: Llama3Parameters = None,
fope_params: FopeParameters = None,
emb_type: RopeType = RopeType.Default,
):
"""build."""
@@ -302,5 +362,12 @@
max_position_embeddings=max_position_embeddings,
longrope_params=longrope_params,
)
elif emb_type == RopeType.Fope:
return FopeRotaryEmbeddingImpl(
dim,
max_position_embeddings=max_position_embeddings,
scaling_factor=scaling_factor,
params=fope_params,
)
else:
raise NotImplementedError(f'Unsupported embedding type: {emb_type}')
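
As a reading aid for the shape handling in `FopeRotaryEmbeddingImpl.forward` above, here is a minimal standalone sketch of the `fope_sep_head` branch in plain PyTorch. The toy sizes, the random coefficients, and the hand-built inverse-frequency spectrum are illustrative assumptions, not values from this PR.

```python
import torch
import torch.nn.functional as F

# Toy sizes (assumptions for illustration only).
batch, seq_len, head_dim = 2, 5, 64
num_kv_heads, num_freq = 4, 8          # num_freq = number of retained inv_freq entries

# Outer product of inverse frequencies and positions, as in forward():
# (1, num_freq, 1) @ (batch, 1, seq_len) -> (batch, num_freq, seq_len) -> (batch, seq_len, num_freq)
inv_freq = 1.0 / (10000.0 ** (torch.arange(num_freq).float() / num_freq))
position_ids = torch.arange(seq_len).float().unsqueeze(0).expand(batch, -1)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)

# fope_sep_head path: broadcast over KV heads, then mix frequency components with
# learned per-head coefficients of shape (num_kv_heads, num_freq, num_freq).
pos_cos = freqs.cos().unsqueeze(1).expand(batch, num_kv_heads, seq_len, -1)
pos_sin = freqs.sin().unsqueeze(1).expand(batch, num_kv_heads, seq_len, -1)
cos_coef = torch.randn(num_kv_heads, num_freq, num_freq)   # placeholder for loaded weights
sin_coef = torch.randn(num_kv_heads, num_freq, num_freq)

cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef)  # (batch, seq_len, heads, num_freq)
sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef)

# Components dropped by the inv_freq cutoff are padded back with the constant 1.0,
# then both halves are duplicated to reach head_dim.
pad = head_dim // 2 - cos.size(-1)
cos = torch.cat((F.pad(cos, (0, pad), value=1.0),) * 2, dim=-1)
sin = torch.cat((F.pad(sin, (0, pad), value=1.0),) * 2, dim=-1)
print(cos.shape, sin.shape)  # torch.Size([2, 5, 4, 64]) twice
```

In the non-separate-head branch the same mixing uses a single `(num_freq, num_freq)` coefficient matrix shared across heads, and the head axis is absent from the result.
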
15 changes: 14 additions & 1 deletion lmdeploy/pytorch/backends/rotary_embedding.py
@@ -4,6 +4,8 @@
from enum import Enum, auto
from typing import List

import torch


class RopeType(Enum):
"""Rotary embedding type."""
@@ -13,6 +15,7 @@ class RopeType(Enum):
Llama3 = auto()
Yarn = auto()
LongRoPEScaling = auto()
Fope = auto()


@dataclass
@@ -43,11 +46,20 @@ class Llama3Parameters:
original_max_position_embeddings: int = 8192


@dataclass
class FopeParameters:
"""Fope parameters."""
num_inv_freq: int = None
num_key_value_heads: int = 1
fope_sep_head: bool = False
inv_freq: torch.Tensor = None


class RotaryEmbeddingImpl(ABC):
"""Rotary embedding implementation api."""

@abstractmethod
def forward(self, x, position_ids):
def forward(self, x, position_ids, **kwargs):
"""forward."""
raise NotImplementedError

@@ -65,6 +77,7 @@ def build(
yarn_params: YarnParameters = None,
longrope_params: LongRoPEScalingParameters = None,
llama3_params: Llama3Parameters = None,
fope_params: FopeParameters = None,
emb_type: RopeType = RopeType.Default,
):
"""build."""
115 changes: 111 additions & 4 deletions lmdeploy/pytorch/nn/rotary_embedding.py
@@ -2,11 +2,13 @@

import math

import torch
from torch import Tensor, nn
from transformers import PretrainedConfig

from ..backends import OpType, get_backend
from ..backends.rotary_embedding import Llama3Parameters, LongRoPEScalingParameters, RopeType, YarnParameters
from ..backends.rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
YarnParameters)


def _get_default_rope_parameters(config: PretrainedConfig):
@@ -92,6 +94,15 @@ def _get_llama3_parameters(config: PretrainedConfig):
return dict(emb_type=RopeType.Llama3, scaling_factor=scaling_factor, llama3_params=params)


def _get_fope_parameters(config: PretrainedConfig):
"""Get fope parameters."""
params = FopeParameters()
params.num_inv_freq = config.num_inv_freq
params.num_key_value_heads = config.num_key_value_heads
params.fope_sep_head = config.fope_sep_head
return dict(use_fope=True, fope_params=params)


def build_rotary_params(config: PretrainedConfig):
"""Get scaling_factor rotary params, and emb_type."""
params = dict(emb_type=RopeType.Default)
@@ -114,6 +125,9 @@ def build_rotary_params(config: PretrainedConfig):
if partial_rotary_factor is not None:
params['partial_rotary_factor'] = partial_rotary_factor

if getattr(config, 'use_fope', False):
params.update(_get_fope_parameters(config))

return params
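
For reference, a small sketch of the config attributes `_get_fope_parameters` reads; the attribute values are hypothetical, and the `use_fope` gate shown just above decides whether they are merged into the rotary params at all.

```python
from transformers import PretrainedConfig

from lmdeploy.pytorch.nn.rotary_embedding import _get_fope_parameters

# Hypothetical attribute values; a real model would carry these fields in its HF config.
config = PretrainedConfig(use_fope=True,
                          num_inv_freq=None,
                          num_key_value_heads=8,
                          fope_sep_head=True)

params = _get_fope_parameters(config)
print(params['use_fope'])     # True
print(params['fope_params'])  # FopeParameters(num_inv_freq=None, num_key_value_heads=8, fope_sep_head=True, inv_freq=None)
```

The resulting `use_fope` flag and `fope_params` are then forwarded to `build_rotary_embedding`, whose FoPE branch appears in the next hunk.
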


@@ -124,8 +138,10 @@ def build_rotary_embedding(dim: int,
yarn_params: YarnParameters = None,
longrope_params: LongRoPEScalingParameters = None,
llama3_params: Llama3Parameters = None,
fope_params: FopeParameters = None,
emb_type: RopeType = RopeType.Default,
partial_rotary_factor: float = None) -> nn.Module:
partial_rotary_factor: float = None,
use_fope: bool = False) -> nn.Module:
"""Build rotary embedding op."""
backend = get_backend()

@@ -134,7 +150,7 @@
# update rope_dim
if partial_rotary_factor is not None:
dim = int(dim * partial_rotary_factor)
return builder.build(dim,
impl = builder.build(dim,
max_position_embeddings,
base,
scaling_factor,
llama3_params=llama3_params,
emb_type=emb_type)

if use_fope:
assert fope_params is not None, 'fope_params should not be None when use_fope is True.'
inv_freq = impl.inv_freq
fope_params.inv_freq = inv_freq
fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params)
return fope

return impl


def build_rotary_embedding_from_config(config: PretrainedConfig) -> nn.Module:
"""Build rotary embedding op from config."""
@@ -169,4 +194,86 @@ def __init__(self):

def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
"""forward."""
return self.impl.forward(query, key, cos, sin, inplace)

assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'
assert cos.dim() <= 3 and sin.dim() <= 3

need_reshape = False
if cos.dim() == 3:
# for fope
need_reshape = True
query_shape = query.shape
key_shape = key.shape
cos = cos.flatten(0, 1)
sin = sin.flatten(0, 1)
seq_len = cos.size(0)
query = query.view(seq_len, -1, query.size(-1))
key = key.view(seq_len, -1, key.size(-1))

query, key = self.impl.forward(query, key, cos, sin, inplace)

if need_reshape:
query = query.view(query_shape)
key = key.view(key_shape)
return query, key


class FopeRotaryEmbedding(nn.Module):
"""Fope rotary embedding."""

def __init__(self, dim: int, max_position_embeddings: int, attention_scaling: float, params: FopeParameters):
super().__init__()

num_key_value_heads, tp = self.update_num_kv_heads(params.num_key_value_heads)
self.tp = tp
params.num_key_value_heads = num_key_value_heads

# build impl
backend = get_backend()
builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding)
self.impl = builder.build(dim,
max_position_embeddings=max_position_embeddings,
scaling_factor=attention_scaling,
fope_params=params,
emb_type=RopeType.Fope)

# setup params
inv_freq = self.impl.inv_freq
self.input_dim = inv_freq.shape[-1]
self.output_dim = inv_freq.shape[-1]
self.cos_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim),
requires_grad=False)
self.sin_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim),
requires_grad=False)
if self.tp:
self.cos_coef.weight_loader = self.weight_loader
self.sin_coef.weight_loader = self.weight_loader

@staticmethod
def update_num_kv_heads(num_key_value_heads: int):
"""Update num_key_value_heads."""
from lmdeploy.pytorch.distributed import get_dist_manager
dist_mgr = get_dist_manager()
dist_ctx = dist_mgr.current_context()
tp = dist_ctx.dist_config.attn_config.tp
if tp > 1:
num_key_value_heads = max(1, num_key_value_heads // tp)
return num_key_value_heads, tp

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"""Weight loader."""
from lmdeploy.pytorch.distributed import get_tp_world_rank
world_size, rank = get_tp_world_rank()
num_key_value_heads = loaded_weight.size(0)

if num_key_value_heads < world_size:
n_replicate = world_size // num_key_value_heads
world_size = num_key_value_heads
rank = rank // n_replicate

loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank]
param.copy_(loaded_weight)

def forward(self, x: Tensor, position_ids: Tensor):
"""forward."""
return self.impl.forward(x, position_ids, sin_coef=self.sin_coef, cos_coef=self.cos_coef)
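
The tensor-parallel handling in `update_num_kv_heads` and `weight_loader` above boils down to a chunk-or-replicate rule; the sketch below restates it as a standalone function. The head count, the coefficient size, and the helper name `shard_fope_coef` are all made up for illustration.

```python
import torch


def shard_fope_coef(loaded_weight: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Mirror of FopeRotaryEmbedding.weight_loader's sharding rule (illustrative only)."""
    num_kv_heads = loaded_weight.size(0)
    if num_kv_heads < world_size:
        # Fewer KV heads than TP ranks: groups of ranks share (replicate) one head.
        n_replicate = world_size // num_kv_heads
        world_size = num_kv_heads
        rank = rank // n_replicate
    return loaded_weight.chunk(world_size, dim=0)[rank]


full = torch.randn(2, 23, 23)   # e.g. 2 KV heads, 23x23 FoPE coefficients
shard = shard_fope_coef(full, world_size=8, rank=5)
print(shard.shape)              # torch.Size([1, 23, 23]); ranks 4-7 all receive head 1
```

This lines up with `update_num_kv_heads`, which allocates `max(1, num_key_value_heads // tp)` heads per rank, so each rank's `(1, 23, 23)` shard matches its locally allocated coefficient parameter.
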