Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 16 additions & 175 deletions optimum/habana/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2Config,
Expand All @@ -36,10 +35,12 @@
from transformers.utils import logging

from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
from ...modeling_rope_utils import GaudiRotaryEmbedding
from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa

has_fused_rope = True
except ImportError:
Expand Down Expand Up @@ -67,107 +68,10 @@
logger = logging.get_logger(__name__)


class GaudiGemma2RotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[Gemma2Config] = None,
):
super().__init__()

# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False)

def _dynamic_frequency_update(self, seq_len, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
# seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(seq_len, device=x.device)

if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

if self.attention_scaling == 1.0:
return (
self._cos_cached[:seq_len].to(dtype=x.dtype),
self._sin_cached[:seq_len].to(dtype=x.dtype),
)
else:
return (
self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
)
class GaudiGemma2RotaryEmbedding(GaudiRotaryEmbedding):
def __init__(self, config: Gemma2Config):
config.rope_scaling = getattr(config, "rope_scaling", None)
super().__init__(config=config)


def gaudi_gemma2_repeat_kv(
Expand Down Expand Up @@ -196,55 +100,6 @@ def gaudi_gemma2_repeat_kv(
return query_states, key_states, value_states, attention_mask


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.matmul(x, y)


class KVCache(torch.nn.Module):
def __init__(self):
super(KVCache, self).__init__()
self.cache = None
self.inp_seq_len = -1

def allocate(self, inp_seq_len, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert self.inp_seq_len == inp_seq_len, (
f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
)
self.cache.fill_(0)

def update(self, prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
if prev.shape == cur.shape:
prev.copy_(cur)
return orig_cur
if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
return prev
else:
return torch.cat((prev, cur), dim=dim)

def get_shape(self):
if self.cache is None:
return None
return self.cache.shape

def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


def gaudi_eager_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
Expand Down Expand Up @@ -288,11 +143,7 @@ class GaudiGemma2Attention(Gemma2Attention):
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)

self.rotary_emb = GaudiGemma2RotaryEmbedding(
self.head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)
self.rotary_emb = GaudiGemma2RotaryEmbedding(config=self.config)

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
Expand Down Expand Up @@ -404,7 +255,9 @@ def pre_attn_forward(
kv_seq_len = past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, kwargs["position_ids"])
query_states, key_states = apply_customized_rope(
query_states, key_states, cos, sin, kwargs["position_ids"], self.training
)

if use_cache:
# reuse k, v, self_attention
Expand Down Expand Up @@ -780,7 +633,9 @@ def forward(
past_seen_tokens,
)
else:
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

# embed positions
hidden_states = inputs_embeds
Expand Down Expand Up @@ -1062,23 +917,9 @@ def prepare_inputs_for_generation(
return model_inputs


def apply_customized_rope(q, k, cos, sin, position_ids):
def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and has_fused_rope:
# TODO: remove `.clone()` when it is fixed in SynapseAI
if k.dtype == torch.bfloat16:
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k,
cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
position_ids,
)
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
)
return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
else:
# keep the same implementation as Transformers v4.37.2
return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])
Loading