Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions optimum/habana/transformers/models/mllama/modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
Expand All @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Mllama model."""

import math
Expand All @@ -26,15 +27,23 @@
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.mllama.configuration_mllama import (
MllamaConfig,
MllamaTextConfig,
)
from transformers.models.mllama.modeling_mllama import (
MllamaCrossAttentionDecoderLayer,
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaSelfAttentionDecoderLayer,
MllamaTextCrossAttention,
MllamaTextModel,
MllamaTextRMSNorm,
MllamaTextSelfAttention,
MllamaVisionAttention,
MllamaVisionConfig,
Expand All @@ -46,13 +55,16 @@
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import (
logging,
)
from transformers.utils import logging

from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask

from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)

try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm
except ImportError:
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None


logger = logging.get_logger(__name__)
Expand All @@ -64,6 +76,32 @@
FusedSDPA = None


class GaudiMllamaTextRMSNorm(MllamaTextRMSNorm):
def __init__(self, hidden_size, eps=1e-6):
"""
MllamaTextRMSNorm is equivalent to T5LayerNorm
"""
super().__init__(hidden_size, eps)
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
"""Copied from MllamaTextRMSNorm::forward https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/mllama/modeling_mllama.py#L475. The only differences are:
- Using FusedRMSNorm"""
orig_dtype = hidden_states.dtype
if FusedRMSNorm is not None:
hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon)
return hidden_states.to(orig_dtype)
else:
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(orig_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
Expand Down Expand Up @@ -479,6 +517,7 @@ class GaudiMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer):
def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
super(GaudiMllamaSelfAttentionDecoderLayer, self).__init__(config, layer_idx)
self.self_attn = GaudiMllamaTextSelfAttention(config, layer_idx=layer_idx)
self.input_layernorm = GaudiMllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
Expand Down