From 361b145edaa644c68e7648504eec1d5df1b2eebe Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Tue, 12 Mar 2024 06:56:07 +0000 Subject: [PATCH 1/3] Fix llama test_model_rope_scaling_0_linear, test_model_rope_scaling_1_dynamic. --- optimum/habana/transformers/modeling_utils.py | 5 +- .../habana/transformers/models/__init__.py | 2 + .../transformers/models/llama/__init__.py | 4 +- .../models/llama/modeling_llama.py | 61 ++++++++++++++----- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 148f749fd3..cb9dfcc47b 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -40,6 +40,8 @@ GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, + GaudiLlamaLinearScalingRotaryEmbedding, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiMistralForCausalLM, GaudiMixtralForCausalLM, GaudiMptForCausalLM, @@ -269,7 +271,8 @@ def adapt_transformers_to_gaudi(): transformers.models.llama.modeling_llama.LlamaMLP = GaudiLlamaMLP transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = GaudiLlamaRotaryEmbedding - + transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = GaudiLlamaLinearScalingRotaryEmbedding + transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = GaudiLlamaDynamicNTKScalingRotaryEmbedding transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward # Optimization for falcon generation on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 6880350aa4..b8c911fdac 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -75,6 +75,8 @@ GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, + 
GaudiLlamaLinearScalingRotaryEmbedding, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, gaudi_llama_rmsnorm_forward, ) from .mistral import ( diff --git a/optimum/habana/transformers/models/llama/__init__.py b/optimum/habana/transformers/models/llama/__init__.py index ab46467a9d..6fecf6151c 100644 --- a/optimum/habana/transformers/models/llama/__init__.py +++ b/optimum/habana/transformers/models/llama/__init__.py @@ -5,5 +5,7 @@ GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, + GaudiLlamaLinearScalingRotaryEmbedding, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, gaudi_llama_rmsnorm_forward, -) +) \ No newline at end of file diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a388b3b470..29355bddc4 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -15,6 +15,8 @@ LlamaModel, LlamaRMSNorm, LlamaRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaDynamicNTKScalingRotaryEmbedding, apply_rotary_pos_emb, logger, ) @@ -210,6 +212,39 @@ def forward(self, x, seq_len=None): self._sin_cached[:seq_len].to(dtype=x.dtype), ) +class GaudiLlamaLinearScalingRotaryEmbedding(LlamaLinearScalingRotaryEmbedding, GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, x, seq_len=None): + return GaudiLlamaRotaryEmbedding.forward(self, x, seq_len) + +class GaudiLlamaDynamicNTKScalingRotaryEmbedding(LlamaDynamicNTKScalingRotaryEmbedding, 
GaudiLlamaRotaryEmbedding): + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, x, seq_len=None): + return GaudiLlamaRotaryEmbedding.forward(self, x, seq_len) class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): @@ -230,13 +265,12 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): - if self.config.rope_scaling is None: - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + # Call rotary emb forward() to update cos/sin cache when inferring more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. 
+ if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -306,7 +340,6 @@ def pre_attn_forward( # TODO: update when auto mp params is enabled in DeepSpeed (cf. https://github.com/HabanaAI/DeepSpeed/blob/94309c7b5dfc1a69858f5c9f25737b2f81a332a5/deepspeed/module_inject/replace_module.py#L440) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: if token_idx is None: @@ -319,11 +352,10 @@ def pre_attn_forward( kv_seq_len = past_key_value[0][-2] else: kv_seq_len = past_key_value[0].shape[-2] - if self.config.rope_scaling is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = self.rotary_emb(value_states, position_ids) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None or reuse_cache: # reuse k, v, self_attention if reuse_cache: @@ -401,7 +433,6 @@ def pre_attn_forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): @@ -660,9 +691,7 @@ def forward( # HPU specific mask generation if ignore_cache_position: - causal_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, input_ids.shape, inputs_embeds, past_seen_tokens - ) + causal_mask = _gaudi_prepare_4d_causal_attention_mask(attention_mask, input_ids.shape if input_ids is not None else (batch_size, seq_length), inputs_embeds, past_seen_tokens) else: causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) # embed positions From df2ebe84c68dc7f971860c0e8e9208d481e58b8b Mon 
Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 12 Mar 2024 10:20:17 +0000 Subject: [PATCH 2/3] Make style --- optimum/habana/transformers/modeling_utils.py | 8 +++++--- optimum/habana/transformers/models/__init__.py | 4 ++-- .../habana/transformers/models/llama/__init__.py | 4 ++-- .../transformers/models/llama/modeling_llama.py | 16 +++++++++++++--- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index cb9dfcc47b..701128ec1e 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -36,12 +36,12 @@ GaudiGPTNeoXForCausalLM, GaudiLlamaAttention, GaudiLlamaDecoderLayer, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaForCausalLM, + GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, - GaudiLlamaLinearScalingRotaryEmbedding, - GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiMistralForCausalLM, GaudiMixtralForCausalLM, GaudiMptForCausalLM, @@ -272,7 +272,9 @@ def adapt_transformers_to_gaudi(): transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = GaudiLlamaRotaryEmbedding transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = GaudiLlamaLinearScalingRotaryEmbedding - transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = GaudiLlamaDynamicNTKScalingRotaryEmbedding + transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = ( + GaudiLlamaDynamicNTKScalingRotaryEmbedding + ) transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward # Optimization for falcon generation on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index b8c911fdac..fff777f72a 
100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -71,12 +71,12 @@ from .llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaForCausalLM, + GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, - GaudiLlamaLinearScalingRotaryEmbedding, - GaudiLlamaDynamicNTKScalingRotaryEmbedding, gaudi_llama_rmsnorm_forward, ) from .mistral import ( diff --git a/optimum/habana/transformers/models/llama/__init__.py b/optimum/habana/transformers/models/llama/__init__.py index 6fecf6151c..ed7c000f59 100644 --- a/optimum/habana/transformers/models/llama/__init__.py +++ b/optimum/habana/transformers/models/llama/__init__.py @@ -1,11 +1,11 @@ from .modeling_llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, + GaudiLlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaForCausalLM, + GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, - GaudiLlamaLinearScalingRotaryEmbedding, - GaudiLlamaDynamicNTKScalingRotaryEmbedding, gaudi_llama_rmsnorm_forward, ) \ No newline at end of file diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 29355bddc4..503aa879b6 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -10,13 +10,13 @@ from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, LlamaForCausalLM, + LlamaLinearScalingRotaryEmbedding, LlamaMLP, LlamaModel, LlamaRMSNorm, LlamaRotaryEmbedding, - LlamaLinearScalingRotaryEmbedding, - LlamaDynamicNTKScalingRotaryEmbedding, apply_rotary_pos_emb, logger, ) @@ -212,6 +212,7 @@ def forward(self, x, seq_len=None): self._sin_cached[:seq_len].to(dtype=x.dtype), ) + class 
GaudiLlamaLinearScalingRotaryEmbedding(LlamaLinearScalingRotaryEmbedding, GaudiLlamaRotaryEmbedding): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -227,6 +228,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): def forward(self, x, seq_len=None): return GaudiLlamaRotaryEmbedding.forward(self, x, seq_len) + class GaudiLlamaDynamicNTKScalingRotaryEmbedding(LlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaRotaryEmbedding): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -246,6 +248,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): def forward(self, x, seq_len=None): return GaudiLlamaRotaryEmbedding.forward(self, x, seq_len) + class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -340,6 +343,7 @@ def pre_attn_forward( # TODO: update when auto mp params is enabled in DeepSpeed (cf. https://github.com/HabanaAI/DeepSpeed/blob/94309c7b5dfc1a69858f5c9f25737b2f81a332a5/deepspeed/module_inject/replace_module.py#L440) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if token_idx is None: @@ -433,6 +437,7 @@ def pre_attn_forward( if not output_attentions: attn_weights = None + return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): @@ -691,7 +696,12 @@ def forward( # HPU specific mask generation if ignore_cache_position: - causal_mask = _gaudi_prepare_4d_causal_attention_mask(attention_mask, input_ids.shape if input_ids is not None else (batch_size, seq_length), inputs_embeds, past_seen_tokens) + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + 
past_seen_tokens, + ) else: causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) # embed positions From 1f8b435f5fc71e63b3f6fe4129785f46a9728954 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:01:58 +0000 Subject: [PATCH 3/3] Fix dynamic test --- .../transformers/models/llama/__init__.py | 2 +- .../models/llama/modeling_llama.py | 47 +++++++++++-------- .../tests/models/llama/test_modeling_llama.py | 3 +- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/optimum/habana/transformers/models/llama/__init__.py b/optimum/habana/transformers/models/llama/__init__.py index ed7c000f59..20703ffd09 100644 --- a/optimum/habana/transformers/models/llama/__init__.py +++ b/optimum/habana/transformers/models/llama/__init__.py @@ -8,4 +8,4 @@ GaudiLlamaModel, GaudiLlamaRotaryEmbedding, gaudi_llama_rmsnorm_forward, -) \ No newline at end of file +) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 503aa879b6..922eef50a1 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -10,13 +10,10 @@ from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, - LlamaDynamicNTKScalingRotaryEmbedding, LlamaForCausalLM, - LlamaLinearScalingRotaryEmbedding, LlamaMLP, LlamaModel, LlamaRMSNorm, - LlamaRotaryEmbedding, apply_rotary_pos_emb, logger, ) @@ -191,7 +188,22 @@ def forward(self, cur, dim, idx): return update(self.cache, cur, dim, idx, self.inp_seq_len) -class GaudiLlamaRotaryEmbedding(LlamaRotaryEmbedding): +class GaudiLlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + 
self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -199,8 +211,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self._cos_cached = emb.cos().to(dtype) - self._sin_cached = emb.sin().to(dtype) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -213,7 +225,7 @@ def forward(self, x, seq_len=None): ) -class GaudiLlamaLinearScalingRotaryEmbedding(LlamaLinearScalingRotaryEmbedding, GaudiLlamaRotaryEmbedding): +class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) @@ -222,31 +234,28 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self._cos_cached = emb.cos().to(dtype) - self._sin_cached = emb.sin().to(dtype) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), 
persistent=False) - def forward(self, x, seq_len=None): - return GaudiLlamaRotaryEmbedding.forward(self, x, seq_len) - -class GaudiLlamaDynamicNTKScalingRotaryEmbedding(LlamaDynamicNTKScalingRotaryEmbedding, GaudiLlamaRotaryEmbedding): +class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len + if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self._cos_cached = emb.cos().to(dtype) - self._sin_cached = emb.sin().to(dtype) - - def forward(self, x, seq_len=None): - return GaudiLlamaRotaryEmbedding.forward(self, x, seq_len) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) class GaudiLlamaAttention(LlamaAttention): diff --git a/tests/transformers/tests/models/llama/test_modeling_llama.py b/tests/transformers/tests/models/llama/test_modeling_llama.py index 6b76834a50..2c505b6811 100644 --- a/tests/transformers/tests/models/llama/test_modeling_llama.py +++ b/tests/transformers/tests/models/llama/test_modeling_llama.py @@ -17,10 +17,11 @@ import unittest from parameterized import parameterized -from transformers import LlamaConfig, is_torch_available, set_seed +from transformers import LlamaConfig, is_torch_available from 
transformers.testing_utils import require_torch, slow from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from optimum.habana.utils import set_seed from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester