complete t5 more output #3370

Merged · 13 commits · Sep 29, 2022
120 changes: 120 additions & 0 deletions paddlenlp/transformers/model_outputs.py
@@ -724,3 +724,123 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
hidden_states: Optional[Tuple[paddle.Tensor]] = None
attentions: Optional[Tuple[paddle.Tensor]] = None
cross_attentions: Optional[Tuple[paddle.Tensor]] = None


@dataclass
class Seq2SeqModelOutput(ModelOutput):
"""
Base class for model outputs that also contain pre-computed hidden states that can speed up sequential
decoding.

Args:
last_hidden_state (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
[Reviewer comment (Contributor): the docstring formatting could be adjusted.]

Sequence of hidden-states at the output of the last layer of the decoder of the model.

If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
decoder_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
encoder_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""

last_hidden_state: paddle.Tensor = None
past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
decoder_attentions: Optional[Tuple[paddle.Tensor]] = None
cross_attentions: Optional[Tuple[paddle.Tensor]] = None
encoder_last_hidden_state: Optional[paddle.Tensor] = None
encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
encoder_attentions: Optional[Tuple[paddle.Tensor]] = None
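# --- Editor's sketch (not part of the PR diff): a toy construction of Seq2SeqModelOutput
# that makes the documented shapes concrete. The config values used here (2 layers,
# 8 heads, head size 64, hidden size 512) are illustrative assumptions, not taken from
# any real checkpoint.
import paddle
from paddlenlp.transformers.model_outputs import Seq2SeqModelOutput

batch_size, num_heads, tgt_len, src_len, head_dim, hidden = 2, 8, 5, 7, 64, 512

# One entry per decoder layer: self-attention key/value followed by cross-attention
# key/value, matching the `past_key_values` layout documented above.
past_key_values = tuple(
    (
        paddle.rand([batch_size, num_heads, tgt_len, head_dim]),  # self-attn key
        paddle.rand([batch_size, num_heads, tgt_len, head_dim]),  # self-attn value
        paddle.rand([batch_size, num_heads, src_len, head_dim]),  # cross-attn key
        paddle.rand([batch_size, num_heads, src_len, head_dim]),  # cross-attn value
    )
    for _ in range(2)  # n_layers
)

out = Seq2SeqModelOutput(
    last_hidden_state=paddle.rand([batch_size, tgt_len, hidden]),
    past_key_values=past_key_values,
    encoder_last_hidden_state=paddle.rand([batch_size, src_len, hidden]),
)

# Fields not supplied keep their Optional default of None, which is exactly what the
# legacy tuple return path filters out.
assert out.decoder_hidden_states is None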

@dataclass
class Seq2SeqLMOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence language models.

Args:
loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss.
logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""

loss: Optional[paddle.Tensor] = None
logits: paddle.Tensor = None
past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
decoder_attentions: Optional[Tuple[paddle.Tensor]] = None
cross_attentions: Optional[Tuple[paddle.Tensor]] = None
encoder_last_hidden_state: Optional[paddle.Tensor] = None
encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
encoder_attentions: Optional[Tuple[paddle.Tensor]] = None
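# --- Editor's sketch (not part of the PR diff): how the new return_dict path surfaces
# these dataclasses to callers of T5Model. The checkpoint name "t5-small" is an
# assumption; any T5 checkpoint registered in PaddleNLP would do.
import paddle
from paddlenlp.transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
model.eval()

encoded = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])

with paddle.no_grad():
    outputs = model(input_ids=input_ids,
                    decoder_input_ids=input_ids,
                    return_dict=True)

# With return_dict=True the result is a Seq2SeqModelOutput, so fields are read by
# name rather than by positional index.
print(outputs.last_hidden_state.shape)          # [batch_size, tgt_len, hidden_size]
print(outputs.encoder_last_hidden_state.shape)  # [batch_size, src_len, hidden_size]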
132 changes: 105 additions & 27 deletions paddlenlp/transformers/t5/modeling.py
@@ -26,6 +26,13 @@

from ..model_utils import PretrainedModel, register_base_model
from ..nezha.modeling import ACT2FN
from ..model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqModelOutput,
Seq2SeqLMOutput,
BaseModelOutput,
ModelOutput,
)

__all__ = [
'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
@@ -944,7 +951,8 @@ def forward(self,
cache=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False):
output_hidden_states=False,
return_dict=False):
assert input_ids is not None, "input_ids can not be None"
input_shape = input_ids.shape
input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
@@ -1051,13 +1059,22 @@ def forward(self,
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )

return tuple(v for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
] if v is not None)
if not return_dict:
return tuple(v for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
] if v is not None)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)

def get_extended_attention_mask(self, attention_mask, input_shape):
if attention_mask.ndim == 3:
@@ -1293,7 +1310,8 @@ def forward(self,
cache=None,
use_cache=True,
output_attentions=False,
output_hidden_states=False):
output_hidden_states=False,
return_dict=False):
r"""
The T5Model forward method, overrides the `__call__()` special method.

@@ -1343,8 +1361,16 @@ def forward(self,
output_hidden_states (bool, optional):
Whether or not to return the output of all hidden layers.
Defaults to `False`.
return_dict (bool, optional):
Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`. If `False`, the output
will be a tuple of tensors. Defaults to `False`.


Returns:
An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` if `return_dict=True`.
Otherwise it returns a tuple of tensors corresponding to the ordered, non-`None` fields of
:class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` (which fields are present depends on the
input arguments), i.e. (`last_hidden_state`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
`cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`).

@@ -1419,8 +1445,10 @@ def forward(self,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)

output_hidden_states=output_hidden_states,
return_dict=return_dict)
elif return_dict and not isinstance(encoder_output, ModelOutput):
encoder_output = convert_encoder_output(encoder_output)
hidden_states = encoder_output[0]

# Decode
@@ -1432,9 +1460,22 @@ def forward(self,
encoder_attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)

return decoder_outputs + encoder_output
output_hidden_states=output_hidden_states,
return_dict=return_dict)

if not return_dict:
return decoder_outputs + encoder_output

return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_output.last_hidden_state,
encoder_hidden_states=encoder_output.hidden_states,
encoder_attentions=encoder_output.attentions,
)
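# --- Editor's sketch (not part of the PR diff): the two return modes of T5Model.forward
# side by side. With the default return_dict=False the legacy tuple is preserved, and its
# leading element still matches Seq2SeqModelOutput.last_hidden_state. The checkpoint name
# "t5-small" is an assumption.
import paddle
from paddlenlp.transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
model.eval()

ids = paddle.to_tensor([tokenizer("PaddleNLP")["input_ids"]])

with paddle.no_grad():
    as_tuple = model(input_ids=ids, decoder_input_ids=ids)                   # legacy tuple
    as_dict = model(input_ids=ids, decoder_input_ids=ids, return_dict=True)  # Seq2SeqModelOutput

# Position 0 of the tuple corresponds to the last_hidden_state field.
assert as_tuple[0].shape == as_dict.last_hidden_state.shape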


class T5ForConditionalGeneration(T5PretrainedModel):
@@ -1490,7 +1531,8 @@ def forward(self,
labels=None,
use_cache=True,
output_attentions=False,
output_hidden_states=False):
output_hidden_states=False,
return_dict=False):
r"""

Args:
@@ -1518,8 +1560,14 @@ def forward(self,
See :class:`T5Model`.
output_hidden_states (bool, optional):
See :class:`T5Model`.
return_dict (bool, optional):
See :class:`T5Model`.

Returns:
An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` if `return_dict=True`.
Otherwise it returns a tuple of tensors corresponding to the ordered, non-`None` fields of
:class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` (which fields are present depends on the
input arguments), i.e. (`loss`, `logits`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
`cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`).

@@ -1581,12 +1629,13 @@ def forward(self,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
output_hidden_states=output_hidden_states,
return_dict=return_dict)

if isinstance(encoder_output, (tuple, list)):
hidden_states = encoder_output[0]
else:
hidden_states = encoder_output
# encoder_output could be a Tensor, tuple or ModelOutput
if isinstance(encoder_output, paddle.Tensor):
encoder_output = (encoder_output, )
hidden_states = encoder_output[0]

if labels is not None and decoder_input_ids is None:
# get decoder inputs from shifting lm labels to the right
@@ -1610,7 +1659,8 @@ def forward(self,
encoder_attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
output_hidden_states=output_hidden_states,
return_dict=return_dict)

sequence_output = decoder_outputs[0]

@@ -1631,11 +1681,26 @@ def forward(self,
loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
labels.flatten())

if not isinstance(encoder_output, (list, tuple)):
encoder_output = (encoder_output, )

output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
return ((loss, ) + output) if loss is not None else output
if not return_dict:
# concatenate the output tuples
output = (lm_logits, ) + decoder_outputs[1:] + encoder_output[0:]

return ((loss, ) + output) if loss is not None else output

if not isinstance(encoder_output, ModelOutput):
encoder_output = convert_encoder_output(encoder_output)

return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_output.last_hidden_state,
encoder_hidden_states=encoder_output.hidden_states,
encoder_attentions=encoder_output.attentions,
)
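# --- Editor's sketch (not part of the PR diff): a training-style call through
# T5ForConditionalGeneration with return_dict=True. The checkpoint name and the
# example sentence pair are illustrative assumptions.
import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

input_ids = paddle.to_tensor([tokenizer("translate English to German: hello")["input_ids"]])
labels = paddle.to_tensor([tokenizer("hallo")["input_ids"]])

with paddle.no_grad():
    outputs = model(input_ids=input_ids, labels=labels, return_dict=True)

# Seq2SeqLMOutput exposes the loss and the per-token vocabulary logits by name;
# decoder_input_ids were built internally by shifting `labels` to the right.
print(float(outputs.loss))
print(outputs.logits.shape)  # [batch_size, target_length, vocab_size]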

@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
@@ -1817,6 +1882,7 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
@@ -1827,9 +1893,21 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
return_dict=return_dict)

return encoder_outputs


T5EncoderModel.base_model_class = T5EncoderModel


def convert_encoder_output(encoder_output):
"""
Convert an `encoder_output` given as a tuple into an instance of `BaseModelOutput`.

Args:
encoder_output (tuple): expected layout is `(last_hidden_state, hidden_states, attentions)`,
where the trailing entries may be missing.
"""
return BaseModelOutput(
last_hidden_state=encoder_output[0],
hidden_states=encoder_output[1] if len(encoder_output) > 1 else None,
attentions=encoder_output[2] if len(encoder_output) > 2 else None,
)
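# --- Editor's sketch (not part of the PR diff): how convert_encoder_output behaves on a
# minimal tuple. The import path of the helper is an assumption based on where it is
# defined above.
import paddle
from paddlenlp.transformers.model_outputs import BaseModelOutput
from paddlenlp.transformers.t5.modeling import convert_encoder_output

# Toy encoder output as produced with return_dict=False and no optional outputs
# requested: only the last hidden state is present.
last_hidden_state = paddle.rand([2, 7, 512])
encoder_tuple = (last_hidden_state, )

encoder_as_output = convert_encoder_output(encoder_tuple)

assert isinstance(encoder_as_output, BaseModelOutput)
assert encoder_as_output.last_hidden_state is last_hidden_state
assert encoder_as_output.hidden_states is None  # absent entries map to None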