Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PretrainedConfig
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache, StaticCache
from transformers.generation import GenerationMixin
from transformers.integrations.deepspeed import is_deepspeed_available
from transformers.modeling_attn_mask_utils import (
Expand Down Expand Up @@ -1794,13 +1794,9 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
11 changes: 3 additions & 8 deletions optimum/habana/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
Expand Down Expand Up @@ -623,13 +623,8 @@ def forward(
if reuse_cache:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_usable_length(seq_length)
else:
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

cache_position = None

Expand Down
10 changes: 3 additions & 7 deletions optimum/habana/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Expand Down Expand Up @@ -600,12 +600,8 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
10 changes: 3 additions & 7 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch.distributed.distributed_c10d import ProcessGroup
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.llama.modeling_llama import (
Expand Down Expand Up @@ -1247,12 +1247,8 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
10 changes: 3 additions & 7 deletions optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache
from transformers.integrations.deepspeed import is_deepspeed_available
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
Expand Down Expand Up @@ -591,12 +591,8 @@ def forward(
if reuse_cache:
past_key_values_length = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length()
else:
past_key_values_length = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
past_key_values_length = past_key_values[0][0].shape[2]

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
Expand Down
10 changes: 3 additions & 7 deletions optimum/habana/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional, Union

import torch
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
Expand Down Expand Up @@ -809,12 +809,8 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import habana_frameworks.torch.core as htcore
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache, StaticCache
from transformers.integrations.deepspeed import is_deepspeed_available
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
Expand Down Expand Up @@ -869,13 +869,9 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
10 changes: 3 additions & 7 deletions optimum/habana/transformers/models/qwen3/modeling_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional, Union

import torch
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
Expand Down Expand Up @@ -804,12 +804,8 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import Cache, StaticCache
from transformers.integrations.deepspeed import is_deepspeed_available
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_outputs import (
Expand Down Expand Up @@ -906,13 +906,9 @@ def forward(
else:
past_seen_tokens = past_key_values[0][0][2]
else:
if use_new_cache:
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
else:
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]
# HPU uses legacy cache path (use_new_cache = False)
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
Expand Down
Loading