Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
GaudiGPTNeoXForCausalLM,
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
GaudiLlamaForCausalLM,
GaudiLlamaLinearScalingRotaryEmbedding,
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiLlamaRotaryEmbedding,
Expand Down Expand Up @@ -269,7 +271,10 @@ def adapt_transformers_to_gaudi():
transformers.models.llama.modeling_llama.LlamaMLP = GaudiLlamaMLP
transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = GaudiLlamaRotaryEmbedding

transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = GaudiLlamaLinearScalingRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = (
GaudiLlamaDynamicNTKScalingRotaryEmbedding
)
transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward

# Optimization for falcon generation on Gaudi
Expand Down
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@
from .llama import (
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
GaudiLlamaForCausalLM,
GaudiLlamaLinearScalingRotaryEmbedding,
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiLlamaRotaryEmbedding,
Expand Down
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/llama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .modeling_llama import (
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
GaudiLlamaForCausalLM,
GaudiLlamaLinearScalingRotaryEmbedding,
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiLlamaRotaryEmbedding,
Expand Down
80 changes: 64 additions & 16 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
logger,
)
Expand Down Expand Up @@ -189,16 +188,31 @@ def forward(self, cur, dim, idx):
return update(self.cache, cur, dim, idx, self.inp_seq_len)


class GaudiLlamaRotaryEmbedding(LlamaRotaryEmbedding):
class GaudiLlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()

self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(dtype)
self._sin_cached = emb.sin().to(dtype)
self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
Expand All @@ -211,6 +225,39 @@ def forward(self, x, seq_len=None):
)


class GaudiLlamaLinearScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding):
    """Rotary embedding variant that divides position indices by ``self.scaling_factor``."""

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the ``_cos_cached`` / ``_sin_cached`` buffers for ``seq_len`` positions.

        Args:
            seq_len: Number of positions to cache; also stored as ``max_seq_len_cached``.
            device: Device on which the position indices are created.
            dtype: Dtype the cached cos/sin tables are cast to.
        """
        self.max_seq_len_cached = seq_len

        # Linear scaling: shrink every position index by the scaling factor.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
        angles = torch.outer(positions, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos().to(dtype)
        sin_table = table.sin().to(dtype)
        self.register_buffer("_cos_cached", cos_table, persistent=False)
        self.register_buffer("_sin_cached", sin_table, persistent=False)


class GaudiLlamaDynamicNTKScalingRotaryEmbedding(GaudiLlamaRotaryEmbedding):
    """Rotary embedding variant that enlarges the base once ``seq_len`` exceeds ``max_position_embeddings``."""

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the ``_cos_cached`` / ``_sin_cached`` buffers for ``seq_len`` positions.

        When ``seq_len`` is larger than ``self.max_position_embeddings`` the
        ``inv_freq`` buffer is first re-registered from a rescaled base.

        Args:
            seq_len: Number of positions to cache; also stored as ``max_seq_len_cached``.
            device: Device on which the frequencies and position indices are created.
            dtype: Dtype the cached cos/sin tables are cast to.
        """
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Grow the base as a function of how far seq_len exceeds the configured maximum.
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            scaled_base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (scaled_base**exponents), persistent=False)

        # Positions are created after the optional inv_freq refresh so they pick up its dtype.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos().to(dtype)
        sin_table = table.sin().to(dtype)
        self.register_buffer("_cos_cached", cos_table, persistent=False)
        self.register_buffer("_sin_cached", sin_table, persistent=False)


class GaudiLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
Expand All @@ -230,13 +277,12 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape)

def update_sincos_cache(self, seq_len):
if self.config.rope_scaling is None:
# Call rotary emb forward() to update the cos/sin cache when inferring more than self.max_position_embeddings
# This helps avoid creating these caches during the actual model forward pass, which
# reduces memory consumption and improves performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
# Call rotary emb forward() to update the cos/sin cache when inferring more than self.max_position_embeddings
# This helps avoid creating these caches during the actual model forward pass, which
# reduces memory consumption and improves performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)

def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
Expand Down Expand Up @@ -319,11 +365,10 @@ def pre_attn_forward(
kv_seq_len = past_key_value[0][-2]
else:
kv_seq_len = past_key_value[0].shape[-2]
if self.config.rope_scaling is None:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
else:
cos, sin = self.rotary_emb(value_states, position_ids)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None or reuse_cache:
# reuse k, v, self_attention
if reuse_cache:
Expand Down Expand Up @@ -661,7 +706,10 @@ def forward(
# HPU specific mask generation
if ignore_cache_position:
causal_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask, input_ids.shape, inputs_embeds, past_seen_tokens
attention_mask,
input_ids.shape if input_ids is not None else (batch_size, seq_length),
inputs_embeds,
past_seen_tokens,
)
else:
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
Expand Down
3 changes: 2 additions & 1 deletion tests/transformers/tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import unittest

from parameterized import parameterized
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers import LlamaConfig, is_torch_available
from transformers.testing_utils import require_torch, slow

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.utils import set_seed

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
Expand Down