diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 274d9ebf3952..12af88997cff 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -372,7 +372,7 @@ def run_call_with_unpacked_inputs(self, *args, **kwargs):
 
         # process the inputs and call the wrapped function
         main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
-        main_input = fn_args_and_kwargs.pop(main_input_name)
+        main_input = fn_args_and_kwargs.pop(main_input_name, None)
         unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)
 
@@ -423,13 +423,13 @@ def input_processing(func, config, input_ids, **kwargs):
         )
         output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
 
-    if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
+    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
         warnings.warn(
             "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
             FutureWarning,
         )
         kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
-    elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
+    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
         kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
 
     if len(kwargs["kwargs_call"]) > 0:
@@ -497,6 +497,7 @@ def input_processing(func, config, input_ids, **kwargs):
             f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
         )
 
+    # Populates any unspecified argument with its default value, according to the signature.
     for name in parameter_names:
         if name not in list(output.keys()) and name != "args":
             output[name] = kwargs.pop(name, signature[name].default)
diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
index 07d8e812f257..341fb8cf933d 100644
--- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -694,6 +694,9 @@ def prepare_inputs_for_generation(
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
@@ -701,7 +704,7 @@ def prepare_inputs_for_generation(
             "decoder_input_ids": decoder_inputs["input_ids"],
             # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
             "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
-            "past_key_values": decoder_inputs["past_key_values"],
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
         return input_dict
diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py
index 23705dc12da8..2065792df21c 100644
--- a/src/transformers/models/gpt2/modeling_tf_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -878,7 +878,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
         return {
             "input_ids": inputs,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
-            "past_key_values": past,
+            "past": past,
             "use_cache": use_cache,
         }
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
index 1d63640af039..8df9be76fc96 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
@@ -725,6 +725,9 @@ def prepare_inputs_for_generation(
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
@@ -732,7 +735,7 @@ def prepare_inputs_for_generation(
             "decoder_input_ids": decoder_inputs["input_ids"],
             # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
             "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
-            "past_key_values": decoder_inputs["past_key_values"],
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
         return input_dict
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 7494e1397ff0..48d5b3885b2c 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -27,6 +27,7 @@
 from huggingface_hub import delete_repo, login
 from requests.exceptions import HTTPError
 from transformers import is_tf_available
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import get_values
 from transformers.testing_utils import tooslow  # noqa: F401
 from transformers.testing_utils import (
@@ -80,6 +81,7 @@
         TFSampleDecoderOnlyOutput,
         TFSampleEncoderDecoderOutput,
     )
+    from transformers.modeling_tf_utils import unpack_inputs
 
 if _tf_gpu_memory_limit is not None:
     gpus = tf.config.list_physical_devices("GPU")
@@ -1553,6 +1555,68 @@ def test_top_k_top_p_filtering(self):
         tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
         tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
 
+    # tests whether the unpack_inputs decorator behaves as expected
+    def test_unpack_inputs(self):
+        class DummyModel:
+            def __init__(self):
+                config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
+                self.config = PretrainedConfig(**config_kwargs)
+
+            @unpack_inputs
+            def call(
+                self, input_ids=None, past=None, output_attentions=None, output_hidden_states=None, return_dict=None
+            ):
+                return input_ids, past, output_attentions, output_hidden_states, return_dict
+
+        dummy_model = DummyModel()
+        input_ids = tf.constant([0, 1, 2, 3])
+        past = tf.constant([4, 5, 6, 7])
+
+        # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
+        output = dummy_model.call(input_ids=input_ids, past=past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 2: Same as above, but with positional arguments.
+        output = dummy_model.call(input_ids, past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 3: We can also pack everything in the first input.
+        output = dummy_model.call(input_ids={"input_ids": input_ids, "past": past})
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 4: Explicit boolean arguments should override the config.
+        output = dummy_model.call(input_ids=input_ids, past=past, output_attentions=False, return_dict=True)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertTrue(output[4])
+
+        # test case 5: Unexpected arguments should raise an exception.
+        with self.assertRaises(ValueError):
+            output = dummy_model.call(input_ids=input_ids, past=past, foo="bar")
+
+        # test case 6: Despite the above, `past_key_values` should be interchangeable with `past`
+        # (the decorator moves it to `past`, or vice-versa, depending on the signature).
+        output = dummy_model.call(input_ids=input_ids, past_key_values=past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
 
 @require_tf
 @is_staging_test
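Note (not part of the patch): the snippet below is a minimal, hypothetical smoke test for the user-facing path this diff repairs. It assumes the public "bert-base-cased" and "gpt2" checkpoints; the cross-attention weights of the composite model are freshly initialized, so the generated text is meaningless, but the call walks through `prepare_inputs_for_generation` and the `past`/`past_key_values` rerouting in `input_processing` touched above.

# Hypothetical smoke test (not from the PR): cached generation with a TF
# encoder-decoder whose decoder (GPT-2) names its cache argument `past`
# rather than `past_key_values`.
from transformers import AutoTokenizer, TFEncoderDecoderModel

model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

input_ids = tokenizer("The quick brown fox", return_tensors="tf").input_ids
# `use_cache=True` makes the decoder return its cache, which the fixed
# `prepare_inputs_for_generation` now also finds under the key `past`.
generated = model.generate(
    input_ids,
    use_cache=True,
    decoder_start_token_id=model.config.decoder.bos_token_id,
    max_length=16,
)
print(tokenizer.batch_decode(generated))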