Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,46 @@ def booleans_processing(config, **kwargs):
return final_booleans


def unpack_inputs(func):
"""
Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
(common case in Keras).

Args:
func (`callable`):
The callable function of the TensorFlow model.

Returns:
A callable that wraps the original `func` with the behavior described above.
"""

original_signature = inspect.signature(func)

@functools.wraps(func)
def run_call_with_unpacked_inputs(self, *args, **kwargs):
# isolates the actual `**kwargs` for the decorated function
kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
fn_args_and_kwargs.update({"kwargs_call": kwargs_call})

# move any arg into kwargs, if they exist
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

# process the inputs and call the wrapped function
main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
main_input = fn_args_and_kwargs.pop(main_input_name)
unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
return func(self, **unpacked_inputs)

# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
# function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
# Keras would attempt to check the first argument against the literal signature of the wrapper.
run_call_with_unpacked_inputs.__signature__ = original_signature

return run_call_with_unpacked_inputs


def input_processing(func, config, input_ids, **kwargs):
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
Expand Down
Loading