Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@
import os
import tempfile
import warnings
from typing import Optional
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...configuration_utils import PretrainedConfig
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, unpack_inputs
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
get_initializer,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
DUMMY_INPUTS,
Expand Down Expand Up @@ -510,22 +517,22 @@ def from_encoder_decoder_pretrained(
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
):
) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
r"""
Returns:

Expand Down Expand Up @@ -720,3 +727,7 @@ def resize_token_embeddings(self, *args, **kwargs):
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
" model.decoder.resize_token_embeddings(...))"
)

def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)