diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index f12e7d3540..26607a13ac 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -90,6 +90,7 @@ gaudi_gpt_neox_attention_forward, gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, + gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, gaudi_gptj_block_forward, gaudi_gptj_model_forward, gaudi_invert_attention_mask, @@ -262,6 +263,9 @@ def adapt_transformers_to_gaudi(): transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.forward = gaudi_gpt_neox_model_forward transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward = gaudi_gpt_neox_layer_forward transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = gaudi_gpt_neox_attention_forward + transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding._set_cos_sin_cache = ( + gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache + ) # Optimization for llama generation on Gaudi transformers.models.llama.modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index aa04c6f3c3..cec8590003 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -61,6 +61,7 @@ gaudi_gpt_neox_attention_forward, gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, + gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, ) from .gptj import ( GaudiGPTJAttention, diff --git a/optimum/habana/transformers/models/gpt_neox/__init__.py b/optimum/habana/transformers/models/gpt_neox/__init__.py index d3f6ab124d..cceb114b82 100644 --- a/optimum/habana/transformers/models/gpt_neox/__init__.py +++ b/optimum/habana/transformers/models/gpt_neox/__init__.py @@ -3,4 +3,5 @@ gaudi_gpt_neox_attention_forward, gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, + 
def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dtype):
    """
    Rebuild the cached rotary cos/sin tables for `seq_len` positions.

    Installed in place of `GPTNeoXRotaryEmbedding._set_cos_sin_cache`, so the
    signature has to stay as is.

    Args:
        seq_len: number of positions to cache; also stored on
            `self.max_seq_len_cached`.
        device: device on which the position indices are created.
        dtype: accepted for signature compatibility only; it is not used here.
            The tables are built from position indices created directly in
            the dtype of `self.inv_freq`.

    Side effects:
        Sets `self.max_seq_len_cached`, `self.cos_cached` and `self.sin_cached`
        as plain attributes. Nothing is returned.
    """
    self.max_seq_len_cached = seq_len
    inv_freq = self.inv_freq

    # Position indices are created directly in the dtype of `inv_freq`.
    positions = torch.arange(self.max_seq_len_cached, device=device, dtype=inv_freq.dtype)

    # One row per position, one column per inverse frequency.
    angles = torch.outer(positions, inv_freq)

    # Two copies of the angles are laid side by side along the last axis. This
    # differs from the paper's layout, but it is a different permutation that
    # yields the same calculation.
    table = torch.cat((angles, angles), dim=-1)

    self.cos_cached = torch.cos(table)
    self.sin_cached = torch.sin(table)