diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 148f749fd3..701128ec1e 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -36,7 +36,9 @@ GaudiGPTNeoXForCausalLM, GaudiLlamaAttention, GaudiLlamaDecoderLayer, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaForCausalLM, + GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, @@ -269,7 +271,10 @@ def adapt_transformers_to_gaudi(): transformers.models.llama.modeling_llama.LlamaMLP = GaudiLlamaMLP transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = GaudiLlamaRotaryEmbedding - + transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = GaudiLlamaLinearScalingRotaryEmbedding + transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = ( + GaudiLlamaDynamicNTKScalingRotaryEmbedding + ) transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward # Optimization for falcon generation on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 6880350aa4..fff777f72a 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -71,7 +71,9 @@ from .llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaForCausalLM, + GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, diff --git a/optimum/habana/transformers/models/llama/__init__.py b/optimum/habana/transformers/models/llama/__init__.py index ab46467a9d..20703ffd09 100644 --- a/optimum/habana/transformers/models/llama/__init__.py +++ b/optimum/habana/transformers/models/llama/__init__.py @@ -1,7 +1,9 @@ from .modeling_llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaForCausalLM, + GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a388b3b470..922eef50a1 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -14,7 +14,6 @@ LlamaMLP, LlamaModel, LlamaRMSNorm, - LlamaRotaryEmbedding, apply_rotary_pos_emb, logger, ) @@ -189,7 +188,22 @@ def forward(self, cur, dim, idx): return update(self.cache, cur, dim, idx, self.inp_seq_len) -class GaudiLlamaRotaryEmbedding(LlamaRotaryEmbedding): +class GaudiLlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -197,8 +211,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self._cos_cached = emb.cos().to(dtype) - self._sin_cached = emb.sin().to(dtype) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -211,6 +225,39 @@ def forward(self, x, seq_len=None): ) +class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + +class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) + + class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -230,13 +277,12 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): - if self.config.rope_scaling is None: - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -319,11 +365,10 @@ def pre_attn_forward( kv_seq_len = past_key_value[0][-2] else: kv_seq_len = past_key_value[0].shape[-2] - if self.config.rope_scaling is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = self.rotary_emb(value_states, position_ids) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None or reuse_cache: # reuse k, v, self_attention if reuse_cache: @@ -661,7 +706,10 @@ def forward( # HPU specific mask generation if ignore_cache_position: causal_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, input_ids.shape, inputs_embeds, past_seen_tokens + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, ) else: causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) diff --git a/tests/transformers/tests/models/llama/test_modeling_llama.py b/tests/transformers/tests/models/llama/test_modeling_llama.py index 6b76834a50..2c505b6811 100644 --- a/tests/transformers/tests/models/llama/test_modeling_llama.py +++ b/tests/transformers/tests/models/llama/test_modeling_llama.py @@ -17,10 +17,11 @@ import unittest from parameterized import parameterized -from transformers import LlamaConfig, is_torch_available, set_seed +from transformers import LlamaConfig, is_torch_available from transformers.testing_utils import require_torch, slow from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from optimum.habana.utils import set_seed from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester