diff --git a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py index 7a29e24571..06a202e08c 100755 --- a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py +++ b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py @@ -23,7 +23,6 @@ import torch.nn.functional as F from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.models.gemma2.modeling_gemma2 import ( Gemma2Attention, Gemma2Config, @@ -36,10 +35,12 @@ from transformers.utils import logging from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask +from ...modeling_rope_utils import GaudiRotaryEmbedding +from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module try: - from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa has_fused_rope = True except ImportError: @@ -67,107 +68,10 @@ logger = logging.get_logger(__name__) -class GaudiGemma2RotaryEmbedding(torch.nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Gemma2Config] = None, - ): - super().__init__() - - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) - - def _dynamic_frequency_update(self, seq_len, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - # seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(seq_len, device=x.device) - - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - if self.attention_scaling == 1.0: - return ( - self._cos_cached[:seq_len].to(dtype=x.dtype), - self._sin_cached[:seq_len].to(dtype=x.dtype), - ) - else: - return ( - self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) +class GaudiGemma2RotaryEmbedding(GaudiRotaryEmbedding): + def __init__(self, config: Gemma2Config): + config.rope_scaling = getattr(config, "rope_scaling", None) + super().__init__(config=config) def gaudi_gemma2_repeat_kv( @@ -196,55 +100,6 @@ def gaudi_gemma2_repeat_kv( return query_states, key_states, value_states, attention_mask -class Matmul(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - -class KVCache(torch.nn.Module): - def __init__(self): - super(KVCache, self).__init__() - self.cache = None - self.inp_seq_len = -1 - - def allocate(self, inp_seq_len, dtype, device, shape): - if self.cache is None or self.cache.shape != shape: - self.inp_seq_len = inp_seq_len - self.cache = torch.zeros(shape, dtype=dtype, device=device) - else: - assert self.inp_seq_len == inp_seq_len, ( - f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" - ) - self.cache.fill_(0) - - def update(self, prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - return prev - else: - return torch.cat((prev, cur), dim=dim) - - def get_shape(self): - if self.cache is None: - return None - return self.cache.shape - - def forward(self, cur, dim, idx): - return self.update(self.cache, cur, dim, idx, self.inp_seq_len) - - def gaudi_eager_attention_forward( module: torch.nn.Module, query: torch.Tensor, @@ -288,11 +143,7 @@ class GaudiGemma2Attention(Gemma2Attention): def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.rotary_emb = GaudiGemma2RotaryEmbedding( - self.head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + self.rotary_emb = GaudiGemma2RotaryEmbedding(config=self.config) self.matmul_qk = Matmul() self.matmul_av = Matmul() @@ -404,7 +255,9 @@ def pre_attn_forward( kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, kwargs["position_ids"]) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, kwargs["position_ids"], self.training + ) if use_cache: # reuse k, v, self_attention @@ -780,7 +633,9 @@ def forward( past_seen_tokens, ) else: - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # embed positions hidden_states = inputs_embeds @@ -1062,23 +917,9 @@ def prepare_inputs_for_generation( return model_inputs -def apply_customized_rope(q, k, cos, sin, position_ids): +def apply_customized_rope(q, k, cos, sin, position_ids, training=True): if q.device.type == "hpu" and has_fused_rope: - # TODO: remove `.clone()` when it is fixed in SynapseAI - if k.dtype == torch.bfloat16: - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ), FusedRoPE.apply( - k, - cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), - sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), - position_ids, - ) - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ), FusedRoPE.apply( - k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ) + return apply_customized_rope_module(q, k, cos, sin, position_ids, training) else: # keep the same implementation as Transformers v4.37.2 return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])