modeling_hunyuan_v1_dense.py
@@ -5,14 +5,14 @@

import math
import warnings
-from typing import Optional, Union, Unpack
+from typing import Optional, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from ...utils import TransformersKwargs

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
@@ -39,6 +39,8 @@
from transformers.utils.import_utils import is_torch_fx_available

from ...generation import GenerationMixin
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
from .configuration_hunyuan_v1_dense import HunYuanDenseV1Config


@@ -1329,7 +1331,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
-**super_kwargs: Unpack[TransformersKwargs]
+**super_kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
Args:
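Note on the `Unpack` import change above (repeated in the MoE file below): `typing.Unpack` only exists on Python 3.11+ (PEP 646), so sourcing `Unpack` from the package-internal `...processing_utils` module instead of `typing` keeps the model importable on older interpreters. A minimal sketch of the compatibility pattern such a module typically relies on; the actual internals of `processing_utils` are not shown in this diff:

    import sys

    if sys.version_info >= (3, 11):
        from typing import Unpack  # added to typing in Python 3.11 (PEP 646)
    else:
        from typing_extensions import Unpack  # backport for older interpreters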
modeling_hunyuan_v1_moe.py
@@ -50,6 +50,8 @@
from transformers.utils.import_utils import is_torch_fx_available

from ...generation import GenerationMixin
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
from .configuration_hunyuan_v1_moe import HunYuanMoeV1Config


@@ -739,7 +741,7 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
-kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+kv_seq_len += past_key_value.get_seq_length(self.layer_idx)

if self.use_rotary_pos_emb:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
@@ -853,7 +855,7 @@ def forward(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
-kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
if self.use_rotary_pos_emb:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
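Both attention hunks above make the same replacement: `past_key_value.get_usable_length(kv_seq_len, self.layer_idx)` becomes `past_key_value.get_seq_length(self.layer_idx)`, the current `Cache` API for querying how many tokens are already cached for a given layer. A standalone sketch of the computation those lines perform, with arbitrary tensor shapes:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()

    # Prefill: cache 3 tokens for layer 0 ([batch, heads, seq_len, head_dim]).
    k = torch.randn(1, 4, 3, 8)
    v = torch.randn(1, 4, 3, 8)
    cache.update(k, v, layer_idx=0)

    # Decode step: one new token arrives.
    new_key_states = torch.randn(1, 4, 1, 8)
    kv_seq_len = new_key_states.shape[-2]   # tokens in this step
    kv_seq_len += cache.get_seq_length(0)   # plus tokens already cached for layer 0
    assert kv_seq_len == 4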
@@ -1393,7 +1395,7 @@ def forward(
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-past_key_values_length = past_key_values.get_usable_length(seq_length)
+past_key_values_length = past_key_values.get_seq_length()

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
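The model-level hunk follows suit, but without a layer index: after a legacy tuple cache is converted with `DynamicCache.from_legacy_cache`, `get_seq_length()` (which defaults to layer 0) gives the number of tokens already processed. A short sketch of that conversion path, with placeholder tensors:

    import torch
    from transformers.cache_utils import Cache, DynamicCache

    # Legacy format: a tuple of (key, value) pairs, one per layer.
    legacy = tuple(
        (torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8)) for _ in range(2)
    )

    past_key_values = legacy
    use_legacy_cache = not isinstance(past_key_values, Cache)
    if use_legacy_cache:
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)

    past_key_values_length = past_key_values.get_seq_length()  # layer 0 by default
    assert past_key_values_length == 5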
@@ -1544,6 +1546,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**super_kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
Args:
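The final hunk gives the MoE causal-LM `forward` the same typed keyword pass-through the dense model already has. The method body is outside this diff, but the usual shape of the pattern in transformers-style code is delegation to the parent class; a hypothetical sketch, where `BaseLM` and the trimmed `TransformersKwargs` are stand-ins rather than real transformers classes:

    from typing import Optional

    import torch
    from typing_extensions import TypedDict, Unpack

    class TransformersKwargs(TypedDict, total=False):  # stand-in TypedDict
        output_attentions: bool
        output_hidden_states: bool

    class BaseLM:
        def forward(self, input_ids=None, **kwargs):
            ...

    class MoeLM(BaseLM):
        def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            **super_kwargs: Unpack[TransformersKwargs],
        ):
            # Extra, type-checked keyword args flow through to the parent untouched.
            return super().forward(input_ids=input_ids, **super_kwargs)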