diff --git a/optimum/habana/transformers/modeling_attn_mask_utils.py b/optimum/habana/transformers/modeling_attn_mask_utils.py new file mode 100755 index 0000000000..1dc452a3f7 --- /dev/null +++ b/optimum/habana/transformers/modeling_attn_mask_utils.py @@ -0,0 +1,505 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + #context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + # Replace triu with below + row_indices = torch.arange(mask.size(0), device=mask.device).view(-1, 1) # Reshape to column vector + col_indices = torch.arange(mask.size(1), device=mask.device) + context_mask = 1 - (col_indices >= row_indices + diagonal).int().expand_as(mask) # Expand to match mask shape + + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. + + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. + + For example, if `attention_mask` is + ``` + [[0, 0, 1], + [1, 1, 1], + [0, 1, 1]] + ``` + and `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + + # Get the index of the first non-zero value for every sample in the batch. + # In the above example, indices = [[2], [0], [1]]] + tmp = torch.arange(attention_mask.shape[1], 0, -1) + indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) + + # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the + # expanded mask will be completely unattended. + left_masked_rows = torch.where(indices > 0)[0] + + if left_masked_rows.shape[0] == 0: + return expanded_mask + indices = indices[left_masked_rows] + + max_len = torch.max(indices) + range_tensor = torch.arange(max_len).unsqueeze(0) + range_tensor = range_tensor.repeat(indices.size(0), 1) + + # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above. + range_tensor[range_tensor >= indices] = 0 + + # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case + if expanded_mask.dim() == 4: + num_masks = expanded_mask.shape[1] + if num_masks == 1: + # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] + mask_slice = (left_masked_rows[:, None], 0, range_tensor) + else: + # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len] + mask_slice = ( + left_masked_rows[:, None, None], + torch.arange(num_masks)[None, :, None], + range_tensor[:, None, :], + ) + else: + # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] + mask_slice = (left_masked_rows[:, None], range_tensor) + + expanded_mask[mask_slice] = unmasked_value + + return expanded_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + batch_size, query_length = input_shape + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: Fix this as well when using torchdynamo with fullgraph=True. + is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) + + if attention_mask is not None: + # 4d mask is passed through + if len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + return attention_mask + + elif not is_tracing:# and torch.all(attention_mask == 1): + if query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + attention_mask = None + elif key_value_length == query_length: + attention_mask = None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + pass + elif query_length > 1 and key_value_length != query_length: + # See the comment above (https://github.com/pytorch/pytorch/issues/108108). + # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. + attention_mask = True + elif is_tracing: + raise ValueError( + 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' + ) + + if attention_mask is None: + expanded_4d_mask = None + elif attention_mask is True: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + # + # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent + # controlflow that can not be captured properly. + # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case. + if query_length > 1 and not is_tracing: + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, attention_mask, unmasked_value=0.0 + ) + + return expanded_4d_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + batch_size, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + + # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: Fix this as well when using torchdynamo with fullgraph=True. + is_tracing = torch.jit.is_tracing() + + if torch.all(mask == 1): + if is_tracing: + pass + elif tgt_len == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + return None + elif key_value_length == tgt_len: + return None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask diff --git a/optimum/habana/transformers/models/codegen/modeling_codegen.py b/optimum/habana/transformers/models/codegen/modeling_codegen.py index 871befd3ed..e70873c1f3 100644 --- a/optimum/habana/transformers/models/codegen/modeling_codegen.py +++ b/optimum/habana/transformers/models/codegen/modeling_codegen.py @@ -87,7 +87,8 @@ def forward( if use_cache is True: # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32. # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38 - present = (key.to(hidden_states.dtype), value) + #present = (key.to(hidden_states.dtype), value) + present = (key, value) else: present = None diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index f9dcb6300c..06974ecf6b 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -20,7 +20,7 @@ SDPContext = False try: - from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV1 as FusedRoPE + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE except ImportError: print("Not using HPU fused kernel for apply_rotary_pos_emb") FusedRoPE = None @@ -29,7 +29,7 @@ import habana_frameworks.torch.core as htcore from torch.nn import CrossEntropyLoss from torch.nn import functional as F -from transformers.modeling_attn_mask_utils import ( +from optimum.habana.transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, @@ -54,8 +54,8 @@ def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: # TODO: remove `.clone()` when SynapseAI v1.15 is released - return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply( - k, cos.clone(), sin.clone(), position_ids + return FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ) else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 8164f5b5f9..c1b59b2c26 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -399,7 +399,7 @@ def prepare_inputs_for_generation( position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: if token_idx is not None: - position_ids = torch.index_select(position_ids, 1, token_idx - 1).unsqueeze(-1) + position_ids = torch.index_select(position_ids, 1, token_idx - 1)#.unsqueeze(-1) else: position_ids = position_ids[:, -input_ids.shape[1] :] else: diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py index 3142a260eb..e0718b9c79 100644 --- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -405,6 +405,6 @@ def prepare_inputs_for_generation( def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: - return FusedRoPE.apply(q, cos, sin, position_ids), FusedRoPE.apply(k, cos, sin, position_ids) + return FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index e7ba02d12f..990324a5f5 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -17,6 +17,7 @@ LlamaForCausalLM, LlamaMLP, LlamaModel, + LlamaRMSNorm, apply_rotary_pos_emb, logger, ) @@ -102,7 +103,6 @@ def forward(self, x, y): class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.matmul_qk = Matmul() self.matmul_av = Matmul() self.past_key = None @@ -224,6 +224,7 @@ def pre_attn_forward( kv_seq_len = past_key_value[0][-2] else: kv_seq_len = past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) @@ -354,6 +355,17 @@ def post_mlp_forward(self, x): class GaudiLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = GaudiLlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) @@ -531,15 +543,22 @@ def forward( ) use_cache = False + #seq_length_with_past = seq_length past_key_values_length = 0 - if use_cache: - if reuse_cache: - past_key_values_length = past_key_values[0][0][2] - else: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + use_legacy_cache = True + do_not_use_new_cache = True # Ignoring new Cache path for HPU + if past_key_values is not None: + if use_cache: + if reuse_cache: + past_key_values_length = past_key_values[0][2] #past_key_values[0][0][2] + else: + if not do_not_use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + #seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -550,8 +569,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - - key_value_length = seq_length + past_key_values_length if self._use_sdpa and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. @@ -559,12 +576,13 @@ def forward( attention_mask, (batch_size, seq_length), inputs_embeds, - key_value_length, + past_key_value_length, ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, key_value_length + + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) # embed positions @@ -573,9 +591,9 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None + next_decoder_cache = () if do_not_use_new_cache else None - for decoder_layer in self.layers: + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -585,7 +603,7 @@ def forward( hidden_states, attention_mask, position_ids, - past_key_values, + None if past_key_values is None else past_key_values[layer_idx], output_attentions, use_cache, attn_softmax_bf16=attn_softmax_bf16, @@ -597,7 +615,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, @@ -610,7 +628,7 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -623,7 +641,7 @@ def forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache if do_not_use_new_cache else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -814,8 +832,8 @@ def prepare_inputs_for_generation( def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: # TODO: remove `.clone()` when SynapseAI v1.15 is released - return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply( - k, cos.clone(), sin.clone(), position_ids + return FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ) else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index b16827d664..e8420901b2 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -27,7 +27,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( +from optimum.habana.transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -77,13 +77,13 @@ def gaudi_mistral_attn_forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + shp = past_key_value[0].shape[-2] if type(past_key_value) == type(tuple()) else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if token_idx is not None: - kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len = shp else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += shp cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: if token_idx is not None: past_key_value[0].index_copy_(2, token_idx - 1, key_states) @@ -94,6 +94,7 @@ def gaudi_mistral_attn_forward( cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -234,12 +235,14 @@ def gaudi_mistral_model_forward( use_cache = False past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + use_legacy_cache = True + do_not_use_new_cache = True + if past_key_values is not None: + if use_cache and not do_not_use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -277,19 +280,20 @@ def gaudi_mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None + next_decoder_cache = () if use_cache else None - for decoder_layer in self.layers: + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_values, + None if past_key_values is None else past_key_values[layer_idx], output_attentions, use_cache, ) @@ -298,7 +302,7 @@ def gaudi_mistral_model_forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, @@ -307,7 +311,7 @@ def gaudi_mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -319,9 +323,8 @@ def gaudi_mistral_model_forward( all_hidden_states += (hidden_states,) next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - + if next_decoder_cache and use_cache: + next_cache = next_decoder_cache if do_not_use_new_cache else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast(