diff --git a/optimum/habana/transformers/models/baichuan/modeling_baichuan.py b/optimum/habana/transformers/models/baichuan/modeling_baichuan.py index 414fd02ba3..2fcf6bec0a 100644 --- a/optimum/habana/transformers/models/baichuan/modeling_baichuan.py +++ b/optimum/habana/transformers/models/baichuan/modeling_baichuan.py @@ -43,6 +43,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging +from ...generation.utils import GaudiGenerationMixin from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_baichuan import BaichuanConfig from .generation_utils import TextIterStreamer, build_chat_input @@ -1163,7 +1164,7 @@ def no_init_weights(_enable=True): _init_weights = old_init_weights -class BaichuanForCausalLM(BaichuanPreTrainedModel): +class BaichuanForCausalLM(BaichuanPreTrainedModel, GaudiGenerationMixin): def __init__(self, config, *model_args, **model_kwargs): super().__init__(config, *model_args, **model_kwargs) self.model = BaichuanModel(config) diff --git a/optimum/habana/transformers/models/chatglm/modeling_chatglm.py b/optimum/habana/transformers/models/chatglm/modeling_chatglm.py index 6f71fef42e..9548cc749b 100644 --- a/optimum/habana/transformers/models/chatglm/modeling_chatglm.py +++ b/optimum/habana/transformers/models/chatglm/modeling_chatglm.py @@ -43,6 +43,7 @@ from transformers.utils import logging from ....utils import warn0 +from ...generation.utils import GaudiGenerationMixin from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_chatglm import ChatGLMConfig @@ -1347,7 +1348,7 @@ def forward( ) -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GaudiGenerationMixin): def __init__(self, config: ChatGLMConfig, empty_init=False, device=None): super().__init__(config) diff --git 
a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index b853090c92..21dce5bd10 100755 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -661,7 +661,7 @@ def forward( ): htcore.mark_step() - hidden_states = decoder_layer( + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -678,8 +678,10 @@ def forward( **kwargs, ) + hidden_states = layer_outputs[0] + if use_cache: - next_decoder_cache += (hidden_states[1],) + next_decoder_cache += (layer_outputs[1],) hidden_states = self.norm(hidden_states) diff --git a/optimum/habana/transformers/models/minicpm/modeling_minicpm.py b/optimum/habana/transformers/models/minicpm/modeling_minicpm.py index 7aee2ebe84..88f11ef59a 100644 --- a/optimum/habana/transformers/models/minicpm/modeling_minicpm.py +++ b/optimum/habana/transformers/models/minicpm/modeling_minicpm.py @@ -54,6 +54,7 @@ from transformers.utils.import_utils import is_torch_fx_available from ....utils import warn0 +from ...generation.utils import GaudiGenerationMixin from .configuration_minicpm import MiniCPM3Config @@ -505,8 +506,7 @@ def forward( value_states = past_value_states.index_add( -2, token_idx - 1, value_states - torch.index_select(past_value_states, -2, token_idx - 1) ) - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states + past_key_value.update(key_states, value_states, self.layer_idx) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale @@ -644,8 +644,7 @@ def forward( value_states = past_value_states.index_add( -2, token_idx - 1, value_states - torch.index_select(past_value_states, -2, token_idx - 1) ) - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states + 
past_key_value.update(key_states, value_states, self.layer_idx) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -854,7 +853,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - usable_length = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + usable_length = past_key_value.get_seq_length(self.layer_idx) if token_idx is None: kv_seq_len += usable_length elif usable_length > 0: @@ -1023,7 +1022,9 @@ def forward( "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.", MINICPM_START_DOCSTRING, ) -class MiniCPM3PreTrainedModel(PreTrainedModel): +class MiniCPM3PreTrainedModel( + PreTrainedModel, +): config_class = MiniCPM3Config base_model_prefix = "model" supports_gradient_checkpointing = True @@ -1200,7 +1201,7 @@ def forward( use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + past_key_values_length = past_key_values.get_seq_length() if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1292,7 +1293,7 @@ def forward( ) -class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel): +class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel, GaudiGenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index efceaeaab0..356d573362 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ 
-385,7 +385,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, diff --git a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py index 4209d0c94c..f92ee91cd0 100644 --- a/optimum/habana/transformers/models/persimmon/modeling_persimmon.py +++ b/optimum/habana/transformers/models/persimmon/modeling_persimmon.py @@ -73,11 +73,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0: + if token_idx is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # When token_idx is used, static seq len = (input token len + max output token len) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -98,14 +98,16 @@ def forward( if past_key_value is not None: if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + if ( + 0 <= self.layer_idx < len(past_key_value) + and 
past_key_value.layers[self.layer_idx].keys is not None + ): + past_key_value.layers[self.layer_idx].keys.index_copy_(2, token_idx - 1, key_states) + past_key_value.layers[self.layer_idx].values.index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + past_key_value.update(key_states, value_states, self.layer_idx) else: # Specific to RoPE models with partial rotation cache_kwargs = { diff --git a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py index c659b24dd9..9596865f3d 100644 --- a/optimum/habana/transformers/models/stablelm/modeling_stablelm.py +++ b/optimum/habana/transformers/models/stablelm/modeling_stablelm.py @@ -71,11 +71,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0: + if token_idx is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # When token_idx is used, static seq len = (input token len + max output token len) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -96,14 +96,16 @@ def forward( if past_key_value is not None: if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + if ( + 0 <= self.layer_idx < len(past_key_value) + and past_key_value.layers[self.layer_idx].keys is not None + ): + past_key_value.layers[self.layer_idx].keys.index_copy_(2, token_idx - 1, key_states) + past_key_value.layers[self.layer_idx].values.index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + past_key_value.update(key_states, value_states, self.layer_idx) else: # Specific to RoPE models with partial rotation cache_kwargs = { diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json index 4dcc73e6cf..8bcdd37231 100644 --- a/tests/baselines/fixture/tests/test_text_generation_example.json +++ 
b/tests/baselines/fixture/tests/test_text_generation_example.json @@ -305,7 +305,7 @@ "throughput": 109.70751574382221 }, "gaudi3": { - "output": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models. DeepSpeed is designed to be scalable, and it can be used to train models on a single machine or on a cluster of machines. DeepSpeed is designed to be efficient,", + "output": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch and is compatible with existing PyTorch code. DeepSpeed is open source and available on GitHub.\n\nDeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch and is compatible with existing PyTorch code. DeepSpeed is open source and available on GitHub.\n\n

What is", "throughput": 135.97272017864475 } }, @@ -415,7 +415,7 @@ "throughput": 134.94827207337997 }, "gaudi3": { - "output": "DeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system.\n\nDeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system", + "output": "DeepSpeed is a machine learning framework that accelerates training and inference of deep learning models. It is designed to be flexible and easy to use, with a focus on performance and scalability. DeepSpeed is built on top of PyTorch, and it provides a set of tools and libraries that can be used to optimize the training and inference of deep learning models.\n\nDeepSpeed is designed to be used with a variety of hardware platforms, including GPUs, TPUs, and CPUs. It provides a", "throughput": 160.48685620965531 } }, @@ -425,7 +425,7 @@ "throughput": 71.29570003665306 }, "gaudi3": { - "output": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with multiple GPUs. It is designed to be easy to use and efficient, and it supports a wide range of models and tasks.\n\nDeepSpeed is a deep learning framework that enables training of large models on a single machine with multiple GPUs. It is designed to be easy to use and efficient, and it supports a wide range of models and tasks.\n\nDeepSpeed is a deep learning framework that enables training of large models on a", + "output": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. 
It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\nThe latest DeepSpeed for PC has come up with a few updates that are better than the previous version. Want to know those? Here are they:\n\n## DeepSpeed Andorid App Summary\n\nDeepSpeed has developed the DeepSpeed for Android. You can find it under the", "throughput": 81.6817273229847 } },