diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index 03b8ea16b3..54cec6a16c 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -2,7 +2,7 @@ # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """PyTorch Mllama model.""" import math @@ -26,8 +27,15 @@ from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.mllama.configuration_mllama import ( + MllamaConfig, + MllamaTextConfig, +) from transformers.models.mllama.modeling_mllama import ( MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, @@ -35,6 +43,7 @@ MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextModel, + MllamaTextRMSNorm, MllamaTextSelfAttention, MllamaVisionAttention, MllamaVisionConfig, @@ -46,13 +55,16 @@ apply_rotary_pos_emb, repeat_kv, ) -from transformers.utils import ( - logging, -) +from transformers.utils import logging + +from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask -from ...modeling_attn_mask_utils import ( - _gaudi_prepare_4d_causal_attention_mask, -) + +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm +except ImportError: + print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None logger = logging.get_logger(__name__) @@ -64,6 +76,32 @@ FusedSDPA = None +class GaudiMllamaTextRMSNorm(MllamaTextRMSNorm): + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__(hidden_size, eps) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """Copied from MllamaTextRMSNorm::forward https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/mllama/modeling_mllama.py#L475. The only differences are: + - Using FusedRMSNorm""" + orig_dtype = hidden_states.dtype + if FusedRMSNorm is not None: + hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(orig_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + class ModuleFusedSDPA(torch.nn.Module): def __init__(self, fusedSDPA): super().__init__() @@ -479,6 +517,7 @@ class GaudiMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer): def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: super(GaudiMllamaSelfAttentionDecoderLayer, self).__init__(config, layer_idx) self.self_attn = GaudiMllamaTextSelfAttention(config, layer_idx=layer_idx) + self.input_layernorm = GaudiMllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self,