31 changes: 18 additions & 13 deletions src/transformers/modeling_tf_utils.py
@@ -312,9 +312,10 @@ def booleans_processing(config, **kwargs):
 
     if tf.executing_eagerly():
         # Pure conv models (such as ConvNext) do not have `output_attentions`
-        final_booleans["output_attentions"] = kwargs.get("output_attentions", None)
-        if final_booleans["output_attentions"] is None:
-            final_booleans["output_attentions"] = config.output_attentions
+        if "output_attentions" in kwargs:
+            final_booleans["output_attentions"] = (
+                kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
+            )
Comment on lines +316 to +319

Contributor Author:
The previous version was passing down `final_booleans["output_attentions"] = False` in pure conv models, which would set the `output_attentions` argument to `False`. The new version results in no argument being passed, which is the desired behavior.

Collaborator:
Can you add a comment that `output_attentions` will be in `kwargs`, with a value of `None` if unset? That change made me pause for a couple of minutes.

         final_booleans["output_hidden_states"] = (
             kwargs["output_hidden_states"]
             if kwargs["output_hidden_states"] is not None
@@ -329,7 +330,9 @@ def booleans_processing(config, **kwargs):
             kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
         )
     else:
-        final_booleans["output_attentions"] = config.output_attentions
+        # Pure conv models (such as ConvNext) do not have `output_attentions`
+        if "output_attentions" in kwargs:
+            final_booleans["output_attentions"] = config.output_attentions
         final_booleans["output_hidden_states"] = config.output_hidden_states
 
     if kwargs.get("return_dict", None) not in (None, True):
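
For clarity, a minimal standalone sketch of the behavior change discussed above (toy values only: `config_output_attentions` stands in for `config.output_attentions`, and the empty `kwargs` dict models a pure conv model whose `call` has no `output_attentions` parameter):

# Hypothetical standalone sketch, not the real model flow.
config_output_attentions = False
kwargs = {}  # pure conv model: the parameter does not exist, so the key is absent

# Old lookup: kwargs.get() always yields a value, so an explicit False was
# passed down even though the model has no `output_attentions` argument.
final_booleans_old = {}
final_booleans_old["output_attentions"] = kwargs.get("output_attentions", None)
if final_booleans_old["output_attentions"] is None:
    final_booleans_old["output_attentions"] = config_output_attentions
print(final_booleans_old)  # {'output_attentions': False}

# New lookup: the key is set only when the signature produced it (with a
# value of None if unset), so nothing is forwarded for pure conv models.
final_booleans_new = {}
if "output_attentions" in kwargs:
    final_booleans_new["output_attentions"] = (
        kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config_output_attentions
    )
print(final_booleans_new)  # {}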
@@ -373,7 +376,7 @@ def run_call_with_unpacked_inputs(self, *args, **kwargs):
         # process the inputs and call the wrapped function
         main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
         main_input = fn_args_and_kwargs.pop(main_input_name, None)
-        unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
+        unpacked_inputs = _input_processing(func, self.config, main_input, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)
 
     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
@@ -384,7 +387,7 @@ def run_call_with_unpacked_inputs(self, *args, **kwargs):
     return run_call_with_unpacked_inputs
 
 
-def input_processing(func, config, input_ids, **kwargs):
+def _input_processing(func, config, input_ids, **kwargs):
     """
     Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
     has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
@@ -402,7 +405,7 @@ def input_processing(func, config, input_ids, **kwargs):
         Two lists, one for the missing layers, and another one for the unexpected layers.
     """
     signature = dict(inspect.signature(func).parameters)
-    signature.pop("kwargs", None)
+    has_kwargs = bool(signature.pop("kwargs", None))
     signature.pop("self", None)
     parameter_names = list(signature.keys())
     output = {}
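
As a side note to the docstring above, a short sketch of the named symbolic inputs it describes (the sequence length of 128 and the second input are illustrative, not required):

import tensorflow as tf

# Each symbolic input carries the name of the matching call() parameter,
# which is how the processing step pairs inputs with parameters.
input_ids = tf.keras.Input(shape=(128,), dtype="int32", name="input_ids")
attention_mask = tf.keras.Input(shape=(128,), dtype="int32", name="attention_mask")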
@@ -432,12 +435,14 @@ def input_processing(func, config, input_ids, **kwargs):
     elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
         kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
 
-    if len(kwargs["kwargs_call"]) > 0:
-        raise ValueError(
-            f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
-        )
-
-    kwargs.pop("kwargs_call")
+    if has_kwargs:
+        output["kwargs"] = kwargs.pop("kwargs_call", {})
+    else:
+        if len(kwargs["kwargs_call"]) > 0:
+            raise ValueError(
+                f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
+            )
+        kwargs.pop("kwargs_call")
Comment on lines +440 to +447

Contributor Author:
encoder_decoder models want the kwargs; all other models will discard them (and throw an error if they are not empty).


     for k, v in kwargs.items():
         if isinstance(v, allowed_types) or v is None:
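
A self-contained sketch of the `has_kwargs` routing reviewed above (toy functions; only the branching mirrors the diff, everything else is illustrative):

import inspect

def route_leftover_kwargs(func, leftover):
    # Models whose signature declares **kwargs (the encoder-decoder models)
    # receive the leftovers; every other model rejects non-empty leftovers.
    signature = dict(inspect.signature(func).parameters)
    has_kwargs = bool(signature.pop("kwargs", None))
    if has_kwargs:
        return {"kwargs": leftover}
    if leftover:
        raise ValueError(
            f"The following keyword arguments are not supported by this model: {list(leftover.keys())}."
        )
    return {}

def encoder_decoder_call(input_ids=None, **kwargs):  # hypothetical stand-in
    pass

def plain_call(input_ids=None):  # hypothetical stand-in
    pass

print(route_leftover_kwargs(encoder_decoder_call, {"decoder_position_ids": 1}))
# -> {'kwargs': {'decoder_position_ids': 1}}
route_leftover_kwargs(plain_call, {"decoder_position_ids": 1})  # raises ValueError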
8 changes: 0 additions & 8 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -551,7 +551,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
 
         if input_ids is not None and inputs_embeds is not None:
@@ -785,7 +784,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
         outputs = self.albert(
             input_ids=input_ids,
@@ -854,7 +852,6 @@ def call(
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         sentence_order_label: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFAlbertForPreTrainingOutput, Tuple[tf.Tensor]]:
         r"""
         Return:
@@ -976,7 +973,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1064,7 +1060,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
@@ -1158,7 +1153,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1244,7 +1238,6 @@ def call(
         start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
         end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
         r"""
         start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
@@ -1355,7 +1348,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
3 changes: 0 additions & 3 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -679,7 +679,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         """
         Args:
@@ -834,7 +833,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
         r"""
         Args:
@@ -1273,7 +1271,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[tf.Tensor] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
9 changes: 0 additions & 9 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -737,7 +737,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
 
         if not self.config.is_decoder:
@@ -1067,7 +1066,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
         r"""
         encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1174,7 +1172,6 @@ def call(
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         next_sentence_label: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBertForPreTrainingOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1302,7 +1299,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1520,7 +1516,6 @@ def call(
         return_dict: Optional[bool] = None,
         next_sentence_label: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFNextSentencePredictorOutput, Tuple[tf.Tensor]]:
         r"""
         Return:
@@ -1628,7 +1623,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
@@ -1723,7 +1717,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
@@ -1857,7 +1850,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1949,7 +1941,6 @@ def call(
         start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
         end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
         r"""
         start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
3 changes: 0 additions & 3 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -662,7 +662,6 @@ def call(
         output_hidden_states=None,
         return_dict=None,
         training=False,
-        **kwargs,
     ):
         """
         Args:
@@ -823,7 +822,6 @@ def call(
         output_hidden_states=None,
         return_dict=None,
         training=False,
-        **kwargs,
     ):
         r"""
         Args:
@@ -1276,7 +1274,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[tf.Tensor] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[Tuple[tf.Tensor], TFSeq2SeqLMOutput]:
         r"""
         labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
3 changes: 0 additions & 3 deletions src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -667,7 +667,6 @@ def call(
         output_hidden_states=None,
         return_dict=None,
         training=False,
-        **kwargs,
     ):
         """
         Args:
@@ -827,7 +826,6 @@ def call(
         output_hidden_states=None,
         return_dict=None,
         training=False,
-        **kwargs,
     ):
         r"""
         Args:
@@ -1253,7 +1251,6 @@ def call(
         return_dict: Optional[bool] = None,
         labels: Optional[tf.Tensor] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[Tuple[tf.Tensor], TFSeq2SeqLMOutput]:
         r"""
         labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
12 changes: 0 additions & 12 deletions src/transformers/models/clip/modeling_tf_clip.py
@@ -504,7 +504,6 @@ def call(
         output_hidden_states: bool,
         return_dict: bool,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
         input_shape = shape_list(input_ids)
 
@@ -593,7 +592,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
         if input_ids is None:
             raise ValueError("You have to specify input_ids")
@@ -632,7 +630,6 @@ def call(
         output_hidden_states: bool,
         return_dict: bool,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
 
         embedding_output = self.embeddings(pixel_values=pixel_values)
@@ -683,7 +680,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
 
         if pixel_values is None:
@@ -762,7 +758,6 @@ def get_text_features(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> tf.Tensor:
 
         if input_ids is None:
@@ -796,7 +791,6 @@ def get_image_features(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> tf.Tensor:
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
@@ -826,7 +820,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFCLIPOutput, Tuple[tf.Tensor]]:
 
         if input_ids is None:
@@ -1058,7 +1051,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
         r"""
         Returns:
@@ -1153,7 +1145,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
         r"""
         Returns:
@@ -1258,7 +1249,6 @@ def get_text_features(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> tf.Tensor:
         r"""
         Returns:
@@ -1297,7 +1287,6 @@ def get_image_features(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> tf.Tensor:
         r"""
         Returns:
@@ -1345,7 +1334,6 @@ def call(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs,
     ) -> Union[TFCLIPOutput, Tuple[tf.Tensor]]:
         r"""
         Returns: