diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 4d3734e78049..88b4fb5ed607 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -37,8 +37,8 @@ TFSequenceSummary, TFSharedEmbeddings, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import ( @@ -350,6 +350,7 @@ def _prune_heads(self, heads_to_prune): """ raise NotImplementedError + @unpack_inputs def call( self, input_ids: Optional[TFModelInputType] = None, @@ -368,55 +369,34 @@ def call( training: Optional[bool] = False, **kwargs, ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - past=past, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs["past"] is None: + if past is None: past_length = 0 - inputs["past"] = [None] * len(self.h) + past = [None] * len(self.h) else: - past_length = shape_list(inputs["past"][0][0])[-2] + past_length = shape_list(past[0][0])[-2] - if inputs["position_ids"] is None: - inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + if position_ids is None: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) - if inputs["attention_mask"] is not None: + if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(inputs["attention_mask"]) - inputs["attention_mask"] = tf.reshape( - inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) + attention_mask_shape = shape_list(attention_mask) + attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -424,24 +404,20 @@ def call( # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. one_cst = tf.constant(1.0) - inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype) - inputs["attention_mask"] = tf.multiply( - tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0) - ) + attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) + attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.config.add_cross_attention and inputs["encoder_attention_mask"] is not None: + if self.config.add_cross_attention and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - inputs["encoder_attention_mask"] = tf.cast( - inputs["encoder_attention_mask"], dtype=inputs["encoder_hidden_states"].dtype - ) - num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=encoder_hidden_states.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 @@ -452,66 +428,64 @@ def call( else: encoder_extended_attention_mask = None - inputs["encoder_attention_mask"] = encoder_extended_attention_mask + encoder_attention_mask = encoder_extended_attention_mask # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if inputs["head_mask"] is not None: + if head_mask is not None: raise NotImplementedError else: - inputs["head_mask"] = [None] * self.num_hidden_layers + head_mask = [None] * self.num_hidden_layers # head_mask = tf.constant([0] * self.num_hidden_layers) - inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]]) + position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding") + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids, mode="embedding") - position_embeds = tf.gather(self.wpe, inputs["position_ids"]) + position_embeds = tf.gather(self.wpe, position_ids) - if inputs["token_type_ids"] is not None: - inputs["token_type_ids"] = tf.reshape( - inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]] - ) - token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding") + if token_type_ids is not None: + token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) + token_type_embeds = self.wte(token_type_ids, mode="embedding") else: token_type_embeds = tf.constant(0.0) - position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype) - token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype) - hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds - hidden_states = self.drop(hidden_states, training=inputs["training"]) + position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype) + token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype) + hidden_states = inputs_embeds + position_embeds + token_type_embeds + hidden_states = self.drop(hidden_states, training=training) output_shape = input_shape + [shape_list(hidden_states)[-1]] - presents = () if inputs["use_cache"] else None - all_attentions = () if inputs["output_attentions"] else None - all_cross_attentions = () if inputs["output_attentions"] and self.config.add_cross_attention else None - all_hidden_states = () if inputs["output_hidden_states"] else None - for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])): - if inputs["output_hidden_states"]: + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past)): + if output_hidden_states: all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) outputs = block( hidden_states, layer_past, - inputs["attention_mask"], - inputs["head_mask"][i], - inputs["encoder_hidden_states"], - inputs["encoder_attention_mask"], - inputs["use_cache"], - inputs["output_attentions"], - training=inputs["training"], + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + training=training, ) hidden_states, present = outputs[:2] - if inputs["use_cache"]: + if use_cache: presents = presents + (present,) - if inputs["output_attentions"]: + if output_attentions: all_attentions = all_attentions + (outputs[2],) if self.config.add_cross_attention and encoder_hidden_states is not None: all_cross_attentions = all_cross_attentions + (outputs[3],) @@ -520,15 +494,15 @@ def call( hidden_states = tf.reshape(hidden_states, output_shape) # Add last hidden state - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if inputs["output_attentions"]: + if output_attentions: # let the number of heads free (-1) so we can extract attention even after head pruning attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - if not inputs["return_dict"]: + if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_attentions, all_cross_attentions] @@ -732,6 +706,7 @@ def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.transformer = TFGPT2MainLayer(config, name="transformer") + @unpack_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -777,9 +752,8 @@ def call( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past`). Set to `False` during training, `True` during generation """ - inputs = input_processing( - func=self.call, - config=self.config, + + outputs = self.transformer( input_ids=input_ids, past=past, attention_mask=attention_mask, @@ -794,23 +768,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - outputs = self.transformer( - input_ids=inputs["input_ids"], - past=inputs["past"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -938,6 +895,7 @@ def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current return model_kwargs + @unpack_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -987,9 +945,8 @@ def call( Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, + + transformer_outputs = self.transformer( input_ids=input_ids, past=past, attention_mask=attention_mask, @@ -1003,37 +960,19 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - transformer_outputs = self.transformer( - input_ids=inputs["input_ids"], - past=inputs["past"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) hidden_states = transformer_outputs[0] logits = self.transformer.wte(hidden_states, mode="linear") loss = None - if inputs["labels"] is not None: + if labels is not None: # shift labels to the left and cut last logit token shifted_logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] + labels = labels[:, 1:] loss = self.hf_compute_loss(labels, shifted_logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1081,6 +1020,7 @@ def __init__(self, config, *inputs, **kwargs): config, initializer_range=config.initializer_range, name="multiple_choice_head" ) + @unpack_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -1133,64 +1073,40 @@ def call( >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] ```""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - past=past, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - mc_token_ids=mc_token_ids, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None: - input_shapes = shape_list(inputs["input_ids"]) + if input_ids is not None: + input_shapes = shape_list(input_ids) else: - input_shapes = shape_list(inputs["inputs_embeds"])[:-1] + input_shapes = shape_list(inputs_embeds)[:-1] seq_length = input_shapes[-1] - flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None - flat_attention_mask = ( - tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None - ) - flat_token_type_ids = ( - tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None - ) - flat_position_ids = ( - tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None - ) + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None transformer_outputs = self.transformer( input_ids=flat_input_ids, - past=inputs["past"], + past=past, attention_mask=flat_attention_mask, token_type_ids=flat_token_type_ids, position_ids=flat_position_ids, - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], + head_mask=head_mask, + inputs_embeds=inputs_embeds, encoder_hidden_states=None, encoder_attention_mask=None, - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) hidden_states = transformer_outputs[0] hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) lm_logits = self.transformer.wte(hidden_states, mode="linear") - mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"]) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) mc_logits = tf.squeeze(mc_logits, axis=-1) - if not inputs["return_dict"]: + if not return_dict: return (lm_logits, mc_logits) + transformer_outputs[1:] return TFGPT2DoubleHeadsModelOutput( @@ -1256,6 +1172,7 @@ def __init__(self, config, *inputs, **kwargs): ) self.transformer = TFGPT2MainLayer(config, name="transformer") + @unpack_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1285,9 +1202,7 @@ def call( Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, + transformer_outputs = self.transformer( input_ids=input_ids, past=past, attention_mask=attention_mask, @@ -1299,24 +1214,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - - transformer_outputs = self.transformer( - input_ids=inputs["input_ids"], - past=inputs["past"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) hidden_states = transformer_outputs[0] @@ -1326,12 +1224,12 @@ def call( if self.config.pad_token_id is None: sequence_lengths = -1 else: - if inputs["input_ids"] is not None: + if input_ids is not None: sequence_lengths = ( tf.reduce_sum( tf.cast( - tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), - dtype=inputs["input_ids"].dtype, + tf.math.not_equal(input_ids, self.config.pad_token_id), + dtype=input_ids.dtype, ), -1, keepdims=False, @@ -1347,7 +1245,7 @@ def call( ) loss = None - if inputs["labels"] is not None: + if labels is not None: assert ( self.config.pad_token_id is not None or logits_shape[0] == 1 ), "Cannot handle batch sizes > 1 if no padding token is defined." @@ -1355,12 +1253,10 @@ def call( if not tf.is_tensor(sequence_lengths): in_logits = logits[0 : logits_shape[0], sequence_lengths] - loss = self.hf_compute_loss( - tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]) - ) + loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels])) pooled_logits = in_logits if in_logits is not None else logits - if not inputs["return_dict"]: + if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output