diff --git a/check_t5_against_hf.py b/check_t5_against_hf.py new file mode 100755 index 000000000000..fe1ec9d570cd --- /dev/null +++ b/check_t5_against_hf.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +import os + + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # or any of {'0', '1', '2'} + +import t5 # noqa: E402 +from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary # noqa: E402 +from transformers import T5Tokenizer # noqa: E402 +from transformers.convert_t5_v1_1_original_tf_checkpoint_to_pytorch import ( # noqa: E402 + convert_tf_checkpoint_to_pytorch, +) +from transformers.modeling_t5v2 import T5Config, T5v2ForConditionalGeneration # noqa: E402 + + +path_to_tf_checkpoint = "/home/patrick/hugging_face/mt5/mt5_mesh_tf" + + +tok = T5Tokenizer.from_pretrained(path_to_tf_checkpoint + "/sentencepiece.model") +tok.save_pretrained(path_to_tf_checkpoint) +config = T5Config.from_pretrained("t5-small") +config.d_ff = 1024 +config.num_decoder_layers = 8 +config.num_layers = 8 +config.num_heads = 6 +# comment this line out if checkpoints for T5v1.1 (instead of mT5) should be checked +config.vocab_size = 250112 + +config.save_pretrained(path_to_tf_checkpoint) + +convert_tf_checkpoint_to_pytorch(path_to_tf_checkpoint, path_to_tf_checkpoint + "/config.json", path_to_tf_checkpoint) + +t5_model = t5.models.MtfModel( model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None, sequence_length={"inputs": 64, "targets": 64}, ) + +vocab_model_path = path_to_tf_checkpoint + "/sentencepiece.model" + +# for T5v1.1 one should set `extra_ids=100`. +vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=0) + +score = t5_model.score( inputs=["Hello there. Let's put more words in more languages than I originally thought."], targets=["Hi I am"], vocabulary=vocab, ) + +model = T5v2ForConditionalGeneration.from_pretrained(path_to_tf_checkpoint, return_dict=True) + +input_ids = tok("Hello there", return_tensors="pt").input_ids +labels = tok("Hi I am", return_tensors="pt").input_ids + +# input_ids and labels tokenize as expected +loss = model(input_ids, labels=labels).loss +# HF returns the mean cross-entropy per target token, while the Mesh TF `score` is a summed log-likelihood, so rescale and negate before comparing +mesh_tf_loss = -(labels.shape[-1] * loss.item()) + +if abs(mesh_tf_loss - score[0][0]) < 1e-4: + print("Success!") +else: + print(f"Fail. HF {mesh_tf_loss} vs. Mesh TF {score[0][0]}") diff --git a/src/transformers/configuration_t5v2.py b/src/transformers/configuration_t5v2.py new file mode 100644 index 000000000000..21fa64ef1f22 --- /dev/null +++ b/src/transformers/configuration_t5v2.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2020, The T5v2 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+""" T5v2 model configuration """ + +from .configuration_utils import PretrainedConfig +from .utils import logging + + +logger = logging.get_logger(__name__) + +T5v2_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "t5-small": "https://huggingface.co/t5-small/resolve/main/config.json", + "t5-base": "https://huggingface.co/t5-base/resolve/main/config.json", + "t5-large": "https://huggingface.co/t5-large/resolve/main/config.json", + "t5-3b": "https://huggingface.co/t5-3b/resolve/main/config.json", + "t5-11b": "https://huggingface.co/t5-11b/resolve/main/config.json", +} + + +class T5v2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.T5v2Model` or a + :class:`~transformers.TFT5v2Model`. It is used to instantiate a T5v2 model according to the specified arguments, + defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration + to that of the T5v2 `t5-small `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + Arguments: + vocab_size (:obj:`int`, `optional`, defaults to 32128): + Vocabulary size of the T5v2 model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.T5v2Model` or + :class:`~transformers.TFT5v2Model`. + n_positions (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + d_model (:obj:`int`, `optional`, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (:obj:`int`, `optional`, defaults to 64): + Size of the key, query, value projections per attention head. :obj:`d_kv` has to be equal to :obj:`d_model + // num_heads`. + d_ff (:obj:`int`, `optional`, defaults to 2048): + Size of the intermediate feed forward layer in each :obj:`T5v2Block`. + num_layers (:obj:`int`, `optional`, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (:obj:`int`, `optional`): + Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not + set. + num_heads (:obj:`int`, `optional`, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32): + The number of buckets to use for each attention layer. + dropout_rate (:obj:`float`, `optional`, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (:obj:`float`, `optional`, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). 
+ """ + model_type = "t5" + + def __init__( + self, + vocab_size=32128, + n_positions=512, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + is_encoder_decoder=True, + pad_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.n_positions = n_positions + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.d_model + + @property + def num_attention_heads(self): + return self.num_heads + + @property + def num_hidden_layers(self): + return self.num_layers diff --git a/src/transformers/convert_t5_v1_1_original_tf_checkpoint_to_pytorch.py b/src/transformers/convert_t5_v1_1_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..7ac47591c184 --- /dev/null +++ b/src/transformers/convert_t5_v1_1_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2018 The T5 authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert a T5v1.1 TF checkpoint to PyTorch.""" + + +import argparse + +from transformers.modeling_t5v2 import T5Config, T5v2ForConditionalGeneration, load_tf_weights_in_t5 +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = T5v2ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_t5(model, config, tf_checkpoint_path) + + # Save the PyTorch model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained T5 model. 
\n" + "This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/modeling_t5v2.py b/src/transformers/modeling_t5v2.py new file mode 100644 index 000000000000..4c7435847cef --- /dev/null +++ b/src/transformers/modeling_t5v2.py @@ -0,0 +1,1312 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5v2 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5v2 (T5 version 1.1) model. """ + + +import copy +import math +import os +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss + +from .activations import ACT2FN +from .configuration_t5v2 import T5v2Config as T5Config +from .file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from .modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from .utils import logging + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5v2Tokenizer" + +#################################################### +# This list contains shortcut names +# for the pretrained weights provided with the models +#################################################### +T5v2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5v2 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions."
) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v, + # which are not required for using the pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info("Skipping {}".format("/".join(name))) + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi": + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name)) + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched for {name}" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys()))) + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module) +#################################################### + + +class T5v2LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5v2 style. No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + # layer norm should always be calculated in float32 + variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) + x = x / torch.sqrt(variance + self.variance_epsilon) + + if self.weight.dtype == torch.float16: + x = x.to(torch.float16) + return self.weight * x + + +class T5v2DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + # TODO: Put in config + self.wi_0_activation = ACT2FN["gelu_new"] + self.wi_1_activation = ACT2FN["linear"] + + def forward(self, hidden_states): + # here check numerically + hidden_0 = self.wi_0(hidden_states) + hidden_0 = self.wi_0_activation(hidden_0) + hidden_1 = self.wi_1(hidden_states) + hidden_1 = self.wi_1_activation(hidden_1) + h = hidden_0 * hidden_1 + h = self.dropout(h) + h = self.wo(h) + return h + + +class T5v2LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + self.DenseReluDense = T5v2DenseReluDense(config) + self.layer_norm = T5v2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + norm_x = self.layer_norm(hidden_states) + y = self.DenseReluDense(norm_x) + layer_output = hidden_states + self.dropout(y) + return layer_output + + +class T5v2Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidirectional=False): + super().__init__() + self.is_bidirectional = is_bidirectional + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.d_kv = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.d_kv + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.d_kv * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative 
attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on. + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + # now n is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = n < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def compute_bias(self, qlen, klen): + """ Compute binned relative position bias """ + context_position = torch.arange(qlen, dtype=torch.long)[:, None] + memory_position = torch.arange(klen, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (qlen, klen) + rp_bucket = self._relative_position_bucket( + relative_position, # shape (qlen, klen) + bidirectional=self.is_bidirectional, + num_buckets=self.relative_attention_num_buckets, + ) + rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) + return values + + def forward( + self, + input, + mask=None, + kv=None, + position_bias=None, + past_key_value=None, + head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if kv is None) or attention over source sentence (provided by kv). + """ + # Input is (bs, qlen, dim) + # Mask is (bs, klen) (non-causal) or (bs, klen, klen) + # past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head) + bs, qlen, dim = input.size() + + if past_key_value is not None: + assert self.is_decoder is True, "Encoder cannot cache past key value states" + assert ( + len(past_key_value) == 2 + ), "past_key_value should have 2 past states: keys and values. 
Got {} past states".format( + len(past_key_value) + ) + real_qlen = qlen + past_key_value[0].shape[2] if query_length is None else query_length + else: + real_qlen = qlen + + if kv is None: + klen = real_qlen + else: + klen = kv.size(1) + + def shape(x): + """ projection """ + return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2) + + def unshape(x): + """ compute context """ + return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim) + + q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head) + + if kv is None: + k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head) + v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head) + elif past_key_value is None: + k = v = kv + k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head) + v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head) + + if past_key_value is not None: + if kv is None: + k_, v_ = past_key_value + k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) + v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + else: + k, v = past_key_value + + if self.is_decoder and use_cache is True: + present_key_value_state = (k, v) + else: + present_key_value_state = None + + # (bs, n_heads, qlen, klen) + scores = torch.matmul( + q, k.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with ONNX opset > 9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_qlen, klen), device=scores.device, dtype=scores.dtype + ) + else: + position_bias = self.compute_bias(real_qlen, klen) + + # if keys and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -qlen:, :] + + if mask is not None: + position_bias = position_bias + mask # (bs, n_heads, qlen, klen) + + scores += position_bias + weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) + weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) + context = unshape(context) # (bs, qlen, dim) + + context = self.o(context) + + outputs = (context,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (weights,) + + return outputs + + +class T5v2LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5v2Attention( + config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=not config.is_decoder + ) + self.layer_norm = T5v2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + norm_x = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + norm_x, + mask=attention_mask, + position_bias=position_bias, + head_mask=head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + y = attention_output[0] + layer_output = hidden_states + self.dropout(y) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class
T5v2LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5v2Attention(config, has_relative_attention_bias=False, is_bidirectional=True) + self.layer_norm = T5v2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + kv, + attention_mask=None, + position_bias=None, + head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + norm_x = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + norm_x, + mask=attention_mask, + kv=kv, + position_bias=position_bias, + head_mask=head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + y = attention_output[0] + layer_output = hidden_states + self.dropout(y) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5v2Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5v2LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5v2LayerCrossAttention(config)) + + self.layer.append(T5v2LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=False, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + error_message = "There should be {} past states. 2 (key / value) for self attention.{} Got {} past key / value states".format( + expected_num_past_key_values, + " 2 (key / value) for cross attention." if expected_num_past_key_values == 4 else "", + len(past_key_value), + ) + assert len(past_key_value) == expected_num_past_key_values, error_message + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + head_mask=head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. 
Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + kv=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + head_mask=head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + outputs = (hidden_states,) + + outputs = outputs + (present_key_value_state,) + attention_outputs + return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + + +class T5v2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """ Initialize the weights """ + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5v2LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5v2Model, T5v2ForConditionalGeneration)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5v2DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5v2Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + d_kv = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, 
std=factor * ((d_model * d_kv) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5v2 it is usually set to the pad_token_id. See T5v2 docs for more information" + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5v2Stack(T5v2PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [T5v2Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5v2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + self.init_weights() + + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required 
mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format( + self + ) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + if self.is_decoder and encoder_attention_mask is not None: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + head_mask=head_mask[i], + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer stores them + # (the self-attention position bias sits at index 2 of layer_outputs, see the tuple layout above) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states,
+ present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5v2_START_DOCSTRING = r""" + + The T5v2 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It is an encoder-decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5v2_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5v2 is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5v2Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5v2 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Provide for sequence to sequence training. T5v2 uses the :obj:`pad_token_id` as the starting token for + :obj:`decoder_input_ids` generation. If :obj:`past_key_values` is used, optionally only the last + :obj:`decoder_input_ids` have to be input (see :obj:`past_key_values`). + + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5v2 Training + <./t5.html#training>`__. If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, + :obj:`decoder_input_ids` takes the value of :obj:`input_ids`. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. 
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare T5v2 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5v2_START_DOCSTRING, +) +class T5v2Model(T5v2PreTrainedModel): + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5v2Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5v2Stack(decoder_config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5v2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + head_mask=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5v2Tokenizer, T5v2Model + + >>> tokenizer = T5v2Tokenizer.from_pretrained('t5-small') + >>> model = T5v2Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True) + + >>> last_hidden_states = outputs.last_hidden_state + """ + if "decoder_past_key_value_states" in kwargs: + warnings.warn( + "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_value_states") + if "decoder_past_key_values" in kwargs: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_values") + assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
+ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""T5v2 Model with a `language modeling` head on top. 
""", T5v2_START_DOCSTRING) +class T5v2ForConditionalGeneration(T5v2PreTrainedModel): + authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5v2Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5v2Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5v2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + head_mask=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. 
+ + Returns: + + Examples:: + + >>> from transformers import T5v2Tokenizer, T5v2ForConditionalGeneration + + >>> tokenizer = T5v2Tokenizer.from_pretrained('t5-small') + >>> model = T5v2ForConditionalGeneration.from_pretrained('t5-small', return_dict=True) + + >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids + >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + + if "lm_labels" in kwargs: + warnings.warn( + "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("lm_labels") + if "decoder_past_key_value_states" in kwargs: + warnings.warn( + "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_value_states") + if "decoder_past_key_values" in kwargs: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_values") + assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs into embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training."
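+ # e.g. at generation step t with a cache covering the first t - 1 positions, a full decoder prefix of shape (batch_size, t) is cut down to its last column (batch_size, 1) below, since all earlier positions are read from past_key_values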
+ if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + # Unlike the original T5, T5v1.1 does not tie the word embeddings to the LM head, so no rescaling by d_model**-0.5 is applied before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "use_cache": use_cache, + } + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at position 0 + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past
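A minimal standalone sketch (not part of the patch) of the gated-GELU feed-forward implemented by T5v2DenseReluDense above, the main architectural change of T5v1.1 over the original T5: a GELU-activated gate (wi_0) is multiplied elementwise with a linear branch (wi_1) before the output projection (wo). The sizes are illustrative, matching the mT5-style config in check_t5_against_hf.py; note the patch uses ACT2FN["gelu_new"] (the tanh approximation of GELU), while torch's F.gelu below is the exact erf-based variant, so values differ slightly.

import torch
import torch.nn.functional as F
from torch import nn

d_model, d_ff = 512, 1024  # illustrative sizes

wi_0 = nn.Linear(d_model, d_ff, bias=False)  # gate branch, GELU-activated
wi_1 = nn.Linear(d_model, d_ff, bias=False)  # linear branch
wo = nn.Linear(d_ff, d_model, bias=False)  # output projection

x = torch.randn(1, 8, d_model)  # (batch, seq_len, d_model)
h = F.gelu(wi_0(x)) * wi_1(x)  # elementwise gating, shape (1, 8, d_ff)
y = wo(h)  # back to (1, 8, d_model)
print(y.shape)  # torch.Size([1, 8, 512])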
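For intuition, a small worked example of the relative-position bucketing in _relative_position_bucket (the import path is an assumption; it matches the file added by this patch). In the bidirectional case with num_buckets=32, offsets to past positions land in buckets [0, 16) and offsets to future positions in [16, 32), with one bucket per exact offset below max_exact = 8 and logarithmically growing bins beyond.

import torch
from transformers.modeling_t5v2 import T5v2Attention  # path added by this patch

# memory_position - query_position: -2 and -20 look back, 2 looks ahead
rel = torch.tensor([[-2, 0, 2, -20]])
buckets = T5v2Attention._relative_position_bucket(rel, bidirectional=True, num_buckets=32)
print(buckets)  # tensor([[ 2,  0, 18, 10]])
# -2 -> bucket 2 (exact), 2 -> 16 + 2 = 18 (future half), -20 -> 8 + 2 = 10 (log bin)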