**src/transformers/modeling_tf_utils.py** (4 additions & 3 deletions)

```diff
@@ -372,7 +372,7 @@ def run_call_with_unpacked_inputs(self, *args, **kwargs):

         # process the inputs and call the wrapped function
         main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
-        main_input = fn_args_and_kwargs.pop(main_input_name)
+        main_input = fn_args_and_kwargs.pop(main_input_name, None)
         unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)
```

```diff
@@ -423,13 +423,13 @@ def input_processing(func, config, input_ids, **kwargs):
         )
         output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")

-    if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
+    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
```
**@gante** (Contributor, Author) commented on Mar 22, 2022:

> This was the root cause of the problem in the decorator -- previously, this function was called from inside `call`, where `kwargs` contained all keyword arguments (at the very least, with their default values).
>
> The decorator now calls this function before `call`, and because the defaults have not been filled in at that point, `kwargs` was empty. This meant that the `past` <> `past_key_values` magic, needed for gpt2+encoder_decoder, was not happening when the decorator was applied to gpt2.

**Member**:

> Makes sense!

```diff
         warnings.warn(
             "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
             FutureWarning,
         )
         kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
-    elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
+    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
         kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")

     if len(kwargs["kwargs_call"]) > 0:
```
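To make the root cause described in the comment above concrete, here is a minimal self-contained sketch (hypothetical names; `handle_past` is a simplified stand-in for the renaming logic in `input_processing`, not the library code itself):

```python
import inspect

# Stand-in for a model's `call`: its cache argument is named `past_key_values`.
def call(input_ids=None, past_key_values=None):
    return input_ids, past_key_values

# Derived from the signature, so it is valid even before any call is made --
# this is what the fixed check relies on.
parameter_names = list(inspect.signature(call).parameters.keys())

def handle_past(kwargs, kwargs_call):
    # Buggy check: `"past_key_values" in kwargs` -- True only when invoked
    # from inside `call`, where every argument (or its default) is present.
    # Fixed check: consult the signature-derived `parameter_names` instead.
    if "past" in kwargs_call and "past_key_values" in parameter_names:
        kwargs["past_key_values"] = kwargs_call.pop("past")
    return kwargs

# Decorator scenario: the user supplied only the deprecated `past`, so
# `kwargs` starts out empty; the signature-based check still renames it.
print(handle_past({}, {"past": [1, 2]}))  # {'past_key_values': [1, 2]}
```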
```diff
@@ -497,6 +497,7 @@ def input_processing(func, config, input_ids, **kwargs):
             f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
         )

+    # Populates any unspecified argument with their default value, according to the signature.
     for name in parameter_names:
         if name not in list(output.keys()) and name != "args":
             output[name] = kwargs.pop(name, signature[name].default)
```
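The default-filling loop added above follows a standard `inspect.signature` pattern; a small runnable sketch under assumed inputs (hypothetical function `f`, not library code):

```python
import inspect

def f(a, b=2, c=3):
    return a, b, c

signature = dict(inspect.signature(f).parameters)
parameter_names = list(signature.keys())

# Any parameter the caller did not supply is populated with the default
# recorded in the signature, mirroring the loop in the hunk above.
kwargs = {"a": 10, "c": 30}
output = {}
for name in parameter_names:
    if name not in output:
        output[name] = kwargs.pop(name, signature[name].default)

print(output)  # {'a': 10, 'b': 2, 'c': 30}
```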
**src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py**

```diff
@@ -694,14 +694,17 @@ def prepare_inputs_for_generation(
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
```
**Contributor**:

> Nice!
```diff
         input_dict = {
             "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_input_ids": decoder_inputs["input_ids"],
             # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
             "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
-            "past_key_values": decoder_inputs["past_key_values"],
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
         return input_dict
```
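The two-key lookup added above can be illustrated with plain dicts (illustrative values only, no library code):

```python
# The wrapped decoder may return its cache as `past_key_values` (most TF
# models) or as `past` (e.g. TF GPT-2 after this PR); check both keys.
for decoder_inputs in ({"past_key_values": "cache-A"}, {"past": "cache-B"}):
    past_key_values = decoder_inputs.get("past_key_values")
    if past_key_values is None:
        past_key_values = decoder_inputs.get("past")
    print(past_key_values)  # prints "cache-A", then "cache-B"
```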
**src/transformers/models/gpt2/modeling_tf_gpt2.py** (1 addition & 1 deletion)

```diff
@@ -878,7 +878,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
             "input_ids": inputs,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
-            "past_key_values": past,
+            "past": past,
```
**Contributor**:

> Thanks for reverting this
"use_cache": use_cache,
}

Expand Down
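For context on why the reverted key name matters: during generation, the dict returned by `prepare_inputs_for_generation` is ultimately splatted into the model's `call`. A plain-Python sketch with a simplified, hypothetical signature (in the real models, the `past` <> `past_key_values` renaming in `input_processing` adds extra tolerance, as test case 6 below demonstrates):

```python
def call(input_ids=None, past=None, use_cache=None):
    # Simplified stand-in for TF GPT-2's `call`, whose cache argument is `past`.
    return past

inputs = {"input_ids": [0], "past": "cache", "use_cache": True}
print(call(**inputs))  # prints: cache

# Without a renaming layer, a mismatched key is rejected outright:
try:
    call(input_ids=[0], past_key_values="cache")
except TypeError as err:
    print(err)  # unexpected keyword argument 'past_key_values'
```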
**src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py**

```diff
@@ -725,14 +725,17 @@ def prepare_inputs_for_generation(
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_input_ids": decoder_inputs["input_ids"],
             # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
             "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
-            "past_key_values": decoder_inputs["past_key_values"],
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
         return input_dict
```
**tests/test_modeling_tf_common.py** (64 additions & 0 deletions)

```diff
@@ -27,6 +27,7 @@
 from huggingface_hub import delete_repo, login
 from requests.exceptions import HTTPError
 from transformers import is_tf_available
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import get_values
 from transformers.testing_utils import tooslow  # noqa: F401
 from transformers.testing_utils import (
@@ -80,6 +81,7 @@
         TFSampleDecoderOnlyOutput,
         TFSampleEncoderDecoderOutput,
     )
+    from transformers.modeling_tf_utils import unpack_inputs

     if _tf_gpu_memory_limit is not None:
         gpus = tf.config.list_physical_devices("GPU")
```
```diff
@@ -1553,6 +1555,68 @@ def test_top_k_top_p_filtering(self):
         tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
         tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)

+    # tests whether the unpack_inputs function behaves as expected
+    def test_unpack_inputs(self):
+        class DummyModel:
+            def __init__(self):
+                config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
+                self.config = PretrainedConfig(**config_kwargs)
+
+            @unpack_inputs
+            def call(
+                self, input_ids=None, past=None, output_attentions=None, output_hidden_states=None, return_dict=None
+            ):
+                return input_ids, past, output_attentions, output_hidden_states, return_dict
+
+        dummy_model = DummyModel()
+        input_ids = tf.constant([0, 1, 2, 3])
+        past = tf.constant([4, 5, 6, 7])
+
+        # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
+        output = dummy_model.call(input_ids=input_ids, past=past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 2: Same as above, but with positional arguments.
+        output = dummy_model.call(input_ids, past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 3: We can also pack everything in the first input.
+        output = dummy_model.call(input_ids={"input_ids": input_ids, "past": past})
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 4: Explicit boolean arguments should override the config.
+        output = dummy_model.call(input_ids=input_ids, past=past, output_attentions=False, return_dict=True)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertTrue(output[4])
+
+        # test case 5: Unexpected arguments should raise an exception.
+        with self.assertRaises(ValueError):
+            output = dummy_model.call(input_ids=input_ids, past=past, foo="bar")
+
+        # test case 6: Despite the above, `past_key_values` should be interchangeable with `past`
+        # (the decorator moves it to `past`, or vice-versa, depending on the signature).
+        output = dummy_model.call(input_ids=input_ids, past_key_values=past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+
 @require_tf
 @is_staging_test
```