diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
index 5394520220eb..cec99465a0ad 100644
--- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
+++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
@@ -5,14 +5,14 @@
 import math
 import warnings
-from typing import Optional, Union, Unpack
+from typing import Optional, Union
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from ...utils import TransformersKwargs
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -39,6 +39,8 @@
 from transformers.utils.import_utils import is_torch_fx_available
 
 from ...generation import GenerationMixin
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
 from .configuration_hunyuan_v1_dense import HunYuanDenseV1Config
 
 
@@ -1329,7 +1331,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **super_kwargs: Unpack[TransformersKwargs]
+        **super_kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         Args:
diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
index 655a4199a0d0..f851ca9457c7 100644
--- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -50,6 +50,8 @@
 from transformers.utils.import_utils import is_torch_fx_available
 
 from ...generation import GenerationMixin
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
 from .configuration_hunyuan_v1_moe import HunYuanMoeV1Config
 
 
@@ -739,7 +741,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
         if self.use_rotary_pos_emb:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
@@ -853,7 +855,7 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
         if self.use_rotary_pos_emb:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -1393,7 +1395,7 @@ def forward(
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+            past_key_values_length = past_key_values.get_seq_length()
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1544,6 +1546,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **super_kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         Args:
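
Note on the cache change: the patch swaps the legacy `past_key_value.get_usable_length(kv_seq_len, self.layer_idx)` call for `past_key_value.get_seq_length(self.layer_idx)` in the attention layers, and for the argument-free `get_seq_length()` at the model level. Below is a minimal, self-contained sketch of the pattern the patched attention code relies on; the tensor shapes are toy values chosen for illustration, and it assumes a transformers version where `DynamicCache` and `Cache.get_seq_length` are available (v4.36+):

```python
# Sketch only, not part of the patch. Toy shapes: (batch, num_heads, seq_len, head_dim).
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0

# Pretend a previous decoding step already stored 5 key/value positions for this layer.
prev_k = torch.randn(1, 2, 5, 8)
prev_v = torch.randn(1, 2, 5, 8)
cache.update(prev_k, prev_v, layer_idx)

# The current step produces 1 new position; this mirrors the patched kv_seq_len computation:
# new tokens plus whatever the cache already holds for this layer.
key_states = torch.randn(1, 2, 1, 8)
kv_seq_len = key_states.shape[-2] + cache.get_seq_length(layer_idx)
print(kv_seq_len)  # 1 new + 5 cached = 6
```

As far as I can tell, for `DynamicCache` the two calls return the same value (there is no maximum cache shape to clamp against), so this should be a behavior-preserving swap here. The import move also looks deliberate: `typing.Unpack` only exists on Python 3.11+, so importing `Unpack` from `...processing_utils` (which, as far as I know, re-exports it from `typing_extensions`) and `TransformersKwargs` from `...utils` matches the convention used by other modeling files and keeps older interpreters working.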