diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 88f167a9f9..3d5f363281 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -185,6 +185,8 @@ gaudi_esm_for_protein_folding_forward, gaudi_esmfolding_trunk_forward, gaudi_falcon_linear_forward, + gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation, + gaudi_FalconMambaModel_forward, gaudi_generate_speech, gaudi_get_extended_attention_mask, gaudi_gpt2_forward, @@ -648,6 +650,11 @@ def adapt_transformers_to_gaudi(): transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = ( gaudi_MambaForCausalLM_update_model_kwargs_for_generation ) + transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaForCausalLM.prepare_inputs_for_generation = ( + gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation + ) + transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaModel.forward = gaudi_FalconMambaModel_forward + transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaRMSNorm.forward = gaudi_llama_rmsnorm_forward # Optimization for Whisper on Gaudi transformers.models.whisper.modeling_whisper.WhisperSdpaAttention = GaudiWhisperSdpaAttention diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 5e9328df0e..28df902edd 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -70,6 +70,10 @@ GaudiFalconModel, gaudi_falcon_linear_forward, ) +from .falcon_mamba import ( + gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation, + gaudi_FalconMambaModel_forward, +) from .gemma import ( GaudiGemmaAttention, GaudiGemmaDecoderLayer, diff --git a/optimum/habana/transformers/models/falcon_mamba/__init__.py b/optimum/habana/transformers/models/falcon_mamba/__init__.py new file mode 100644 index 0000000000..efbcaaa072 --- /dev/null +++ b/optimum/habana/transformers/models/falcon_mamba/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_falcon_mamba import ( + gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation, + gaudi_FalconMambaModel_forward, +) diff --git a/optimum/habana/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/optimum/habana/transformers/models/falcon_mamba/modeling_falcon_mamba.py new file mode 100644 index 0000000000..5d618fac91 --- /dev/null +++ b/optimum/habana/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch FALCONMAMBA model.""" + +from typing import Optional, Tuple, Union + +import habana_frameworks.torch.core as htcore +import torch +from transformers.cache_utils import MambaCache +from transformers.models.falcon_mamba.modeling_falcon_mamba import FalconMambaOutput +from transformers.utils import ( + logging, +) + + +logger = logging.get_logger(__name__) + +""" +Copys from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py#L635 +The only differences are: +- Use the mark_step function to reduce the graph compiling time. +""" + + +def gaudi_FalconMambaModel_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + lazy_mode: Optional[bool] = True, +) -> Union[Tuple, FalconMambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if lazy_mode: + htcore.mark_step() + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return FalconMambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +""" +Copys from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py#L762 +The only differences are: +- Use the torch.index_select function to replace the slicing operation of Line 51 +""" + + +def gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **kwargs, +): + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + # input_ids = input_ids[:, -1].unsqueeze(-1) + idx = torch.tensor([input_ids.size(1) - 1], device=input_ids.device) + input_ids = torch.index_select(input_ids, 1, idx) + + if attention_mask is not None: + attention_mask = None + + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + "attention_mask": attention_mask, + } + ) + return model_inputs diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index e411f64789..6be0482f7b 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -56,6 +56,7 @@ ("EleutherAI/gpt-neo-2.7B", 1, False, 257.2476416844122, False), ("facebook/xglm-1.7B", 1, False, 357.46365062825083, False), ("CohereForAI/c4ai-command-r-v01", 1, False, 29.50315234651154, False), + ("tiiuae/falcon-mamba-7b", 1, False, 47.1464839567739), ("openbmb/MiniCPM3-4B", 1, False, 65.116, False), ], "fp8": [