diff --git a/src/transformers/configuration_t5.py b/src/transformers/configuration_t5.py index 6c9b9d692953..7c9e3f38d795 100644 --- a/src/transformers/configuration_t5.py +++ b/src/transformers/configuration_t5.py @@ -75,7 +75,6 @@ class T5Config(PretrainedConfig): def __init__( self, vocab_size=32128, - n_positions=512, d_model=512, d_kv=64, d_ff=2048, @@ -98,7 +97,6 @@ def __init__( **kwargs, ) self.vocab_size = vocab_size - self.n_positions = n_positions self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff @@ -112,10 +110,6 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.initializer_factor = initializer_factor - @property - def max_position_embeddings(self): - return self.n_positions - @property def hidden_size(self): return self.d_model diff --git a/src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py index 67b65faba97d..aed3c1e5e25f 100755 --- a/src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py @@ -17,8 +17,6 @@ import argparse -import torch - from transformers import T5Config, T5Model, load_tf_weights_in_t5 from transformers.utils import logging @@ -37,7 +35,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du # Save pytorch-model print("Save PyTorch model to {}".format(pytorch_dump_path)) - torch.save(model.state_dict(), pytorch_dump_path) + model.save_pretrained(pytorch_dump_path) if __name__ == "__main__": diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 51ea2560b393..88861b10a8c3 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -115,12 +115,12 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): scope_names = [m_name] if scope_names[0] in ["kernel", "scale", "embedding"]: pointer = getattr(pointer, "weight") - # elif scope_names[0] == 'scale': - # pointer = getattr(pointer, 'weight') - # elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': - # pointer = getattr(pointer, 'bias') - # elif scope_names[0] == 'squad': - # pointer = getattr(pointer, 'classifier') + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") else: try: pointer = getattr(pointer, scope_names[0]) @@ -147,7 +147,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): tf_weights.pop(txt_name, None) logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys()))) - # logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys()))) return model @@ -167,14 +166,15 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, x): + def forward(self, hidden_states): # layer norm should always be calculated in float32 - variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) - x = x / torch.sqrt(variance + self.variance_epsilon) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # convert into float16 if necessary if self.weight.dtype == torch.float16: - x = x.to(torch.float16) - return self.weight * x + hidden_states = hidden_states.to(torch.float16) + 
return self.weight * hidden_states class T5DenseReluDense(nn.Module): @@ -185,11 +185,11 @@ def __init__(self, config): self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): - h = self.wi(hidden_states) - h = F.relu(h) - h = self.dropout(h) - h = self.wo(h) - return h + hidden_states = self.wi(hidden_states) + hidden_states = F.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states class T5LayerFF(nn.Module): @@ -200,25 +200,24 @@ def __init__(self, config): self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): - norm_x = self.layer_norm(hidden_states) - y = self.DenseReluDense(norm_x) - layer_output = hidden_states + self.dropout(y) - return layer_output + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states class T5Attention(nn.Module): - def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidirectional=False): + def __init__(self, config: T5Config, has_relative_attention_bias=False): super().__init__() - self.is_bidirectional = is_bidirectional self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = config.relative_attention_num_buckets self.d_model = config.d_model - self.d_kv = config.d_kv + self.key_value_proj_dim = config.d_kv self.n_heads = config.num_heads self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.d_kv + self.inner_dim = self.n_heads * self.key_value_proj_dim # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -233,7 +232,9 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidir def prune_heads(self, heads): if len(heads) == 0: return - heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads) + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) # Prune linear layers self.q = prune_linear_layer(self.q, index) self.k = prune_linear_layer(self.k, index) @@ -241,7 +242,7 @@ def prune_heads(self, heads): self.o = prune_linear_layer(self.o, index, dim=1) # Update hyper params self.n_heads = self.n_heads - len(heads) - self.inner_dim = self.d_kv * self.n_heads + self.inner_dim = self.key_value_proj_dim * self.n_heads self.pruned_heads = self.pruned_heads.union(heads) @staticmethod @@ -266,49 +267,52 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets Returns: a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) """ - ret = 0 - n = -relative_position + relative_buckets = 0 if bidirectional: num_buckets //= 2 - ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets - n = torch.abs(n) + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) else: - n = torch.max(n, torch.zeros_like(n)) - # now n is in the range [0, inf) + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions max_exact = num_buckets // 2 - is_small = n < max_exact + is_small = 
relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) ).to(torch.long) - val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) - ret += torch.where(is_small, n, val_if_large) - return ret + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets - def compute_bias(self, qlen, klen): + def compute_bias(self, query_length, key_length): """ Compute binned relative position bias """ - context_position = torch.arange(qlen, dtype=torch.long)[:, None] - memory_position = torch.arange(klen, dtype=torch.long)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket( - relative_position, # shape (qlen, klen) - bidirectional=self.is_bidirectional, + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), num_buckets=self.relative_attention_num_buckets, ) - rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) - values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values def forward( self, - input, + hidden_states, mask=None, - kv=None, + key_value_states=None, position_bias=None, past_key_value=None, head_mask=None, @@ -317,106 +321,113 @@ def forward( output_attentions=False, ): """ - Self-attention (if kv is None) or attention over source sentence (provided by key_value_states). + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ - # Input is (bs, qlen, dim) - # Mask is (bs, klen) (non-causal) or (bs, klen, klen) - # past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head) - bs, qlen, dim = input.size() + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length if past_key_value is not None: - assert self.is_decoder is True, "Encoder cannot cache past key value states" assert ( len(past_key_value) == 2 ), "past_key_value should have 2 past states: keys and values. 
Got {} past states".format( len(past_key_value) ) - real_qlen = qlen + past_key_value[0].shape[2] if query_length is None else query_length - else: - real_qlen = qlen + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - if kv is None: - klen = real_qlen - else: - klen = kv.size(1) + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def shape(x): + def shape(states): """ projection """ - return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2) - - def unshape(x): - """ compute context """ - return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim) - - q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head) - - if kv is None: - k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head) - elif past_key_value is None: - k = v = kv - k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head) - - if past_key_value is not None: - if kv is None: - k_, v_ = past_key_value - k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) - v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) - else: - k, v = past_key_value + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """ reshape """ + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """ projects hidden states correctly to key/query states """ + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) - if self.is_decoder and use_cache is True: - present_key_value_state = ((k, v),) - else: - present_key_value_state = (None,) + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) - # (bs, n_heads, qlen, klen) + # compute scores scores = torch.matmul( - q, k.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with onnx op>9 + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 if position_bias is None: if not self.has_relative_attention_bias: - raise ValueError("No position_bias provided and no weights to compute position_bias") - position_bias = self.compute_bias(real_qlen, klen) + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + else: + position_bias = self.compute_bias(real_seq_length, key_length) # if key and values are already calculated # we want only the last query position bias 
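For reference, the branching that the new `project` closure above folds together can be exercised in isolation. The sketch below is illustrative only (toy shapes, a stand-in `proj` linear layer rather than the module's own `self.k`/`self.v`); it mirrors how cached key/value states are concatenated for self-attention and reused as-is for cross-attention:

import torch

batch_size, n_heads, head_dim = 2, 8, 64
d_model = n_heads * head_dim
proj = torch.nn.Linear(d_model, n_heads * head_dim, bias=False)  # stands in for self.k or self.v

def shape(states):
    # (batch_size, seq_length, inner_dim) -> (batch_size, n_heads, seq_length, head_dim)
    return states.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)

def project(hidden_states, proj_layer, key_value_states=None, past_key_value=None):
    if key_value_states is None:                          # self-attention
        states = shape(proj_layer(hidden_states))         # project the current tokens
        if past_key_value is not None:
            states = torch.cat([past_key_value, states], dim=2)  # append to the cached steps
    elif past_key_value is None:                          # cross-attention, first step
        states = shape(proj_layer(key_value_states))      # project the encoder output once
    else:                                                 # cross-attention, later steps
        states = past_key_value                           # reuse the cached projection
    return states

new_token = torch.randn(batch_size, 1, d_model)               # one freshly decoded token
cache = torch.randn(batch_size, n_heads, 4, head_dim)         # 4 previously cached steps
print(project(new_token, proj, past_key_value=cache).shape)   # torch.Size([2, 8, 5, 64])

In the self-attention path each step grows the key length by one while the query side covers only the new tokens, which is why the shared position bias is sliced to its last `seq_length` query rows just below.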
if past_key_value is not None: - position_bias = position_bias[:, :, -qlen:, :] + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (bs, n_heads, qlen, klen) + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) scores += position_bias - weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) - weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) + attn_weights = F.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = F.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) # Mask heads if we want to if head_mask is not None: - weights = weights * head_mask - - context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) + attn_weights = attn_weights * head_mask - context = self.o(context) + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) - outputs = (context,) + present_key_value_state + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) if output_attentions: - outputs = outputs + (weights,) - if self.has_relative_attention_bias: - outputs = outputs + (position_bias,) + outputs = outputs + (attn_weights,) return outputs class T5LayerSelfAttention(nn.Module): def __init__(self, config, has_relative_attention_bias=False): super().__init__() - self.SelfAttention = T5Attention( - config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=not config.is_decoder - ) + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -430,9 +441,9 @@ def forward( use_cache=False, output_attentions=False, ): - norm_x = self.layer_norm(hidden_states) + normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( - norm_x, + normed_hidden_states, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, @@ -440,25 +451,22 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - y = attention_output[0] - layer_output = hidden_states + self.dropout(y) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs class T5LayerCrossAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config): super().__init__() - self.EncDecAttention = T5Attention( - config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=True - ) + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( self, hidden_states, - kv, + key_value_states, attention_mask=None, position_bias=None, head_mask=None, @@ -467,11 +475,11 @@ def forward( query_length=None, output_attentions=False, ): - norm_x = 
self.layer_norm(hidden_states) + normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( - norm_x, + normed_hidden_states, mask=attention_mask, - kv=kv, + key_value_states=key_value_states, position_bias=position_bias, head_mask=head_mask, past_key_value=past_key_value, @@ -479,8 +487,7 @@ def forward( query_length=query_length, output_attentions=output_attentions, ) - y = attention_output[0] - layer_output = hidden_states + self.dropout(y) + layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs @@ -492,7 +499,7 @@ def __init__(self, config, has_relative_attention_bias=False): self.layer = nn.ModuleList() self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) if self.is_decoder: - self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append(T5LayerCrossAttention(config)) self.layer.append(T5LayerFF(config)) @@ -550,7 +557,7 @@ def forward( cross_attention_outputs = self.layer[1]( hidden_states, - kv=encoder_hidden_states, + key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, head_mask=head_mask, @@ -619,12 +626,12 @@ def _init_weights(self, module): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model - d_kv = self.config.d_kv + key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * d_kv) ** -0.5)) + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) @@ -775,20 +782,20 @@ def forward( # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) hidden_states, present_key_value_state = layer_outputs[:2] - if i == 0: - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[3 if output_attentions else 2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3] + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention weights), + # (self-attention position bias), (cross-attention weights), (cross-attention position bias) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] # append next layer key value states if 
use_cache: present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) + all_attentions = all_attentions + (layer_outputs[3],) if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],) + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -920,6 +927,12 @@ def forward( T5_START_DOCSTRING, ) class T5Model(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + def __init__(self, config: T5Config): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.d_model) @@ -1063,7 +1076,12 @@ def forward( @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) class T5ForConditionalGeneration(T5PreTrainedModel): - authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"] + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index e09f5c733164..dcdfc91e1f98 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -81,10 +81,10 @@ def build(self, input_shape): self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones") super().build(input_shape) - def call(self, x): - variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True) - x = x * tf.math.rsqrt(variance + self.variance_epsilon) - return self.weight * x + def call(self, hidden_states): + variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True) + hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states class TFT5DenseReluDense(tf.keras.layers.Layer): @@ -96,11 +96,11 @@ def __init__(self, config, **kwargs): self.act = tf.keras.activations.relu def call(self, hidden_states, training=False): - h = self.wi(hidden_states) - h = self.act(h) - h = self.dropout(h, training=training) - h = self.wo(h) - return h + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states class TFT5LayerFF(tf.keras.layers.Layer): @@ -111,18 +111,17 @@ def __init__(self, config, **kwargs): self.dropout = tf.keras.layers.Dropout(config.dropout_rate) def call(self, hidden_states, training=False): - norm_x = self.layer_norm(hidden_states) - y = self.DenseReluDense(norm_x, training=training) - layer_output = hidden_states + self.dropout(y, training=training) - return layer_output + normed_hidden_states = self.layer_norm(hidden_states) + dense_output = self.DenseReluDense(normed_hidden_states, training=training) + hidden_states = hidden_states + self.dropout(dense_output, training=training) + return hidden_states class TFT5Attention(tf.keras.layers.Layer): NEW_ID = 
itertools.count() - def __init__(self, config, has_relative_attention_bias=False, is_bidirectional=False, **kwargs): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): super().__init__(**kwargs) - self.is_bidirectional = is_bidirectional self.layer_id = next(TFT5Attention.NEW_ID) self.is_decoder = config.is_decoder self.use_cache = config.use_cache @@ -131,9 +130,9 @@ def __init__(self, config, has_relative_attention_bias=False, is_bidirectional=F self.relative_attention_num_buckets = config.relative_attention_num_buckets self.d_model = config.d_model - self.d_kv = config.d_kv + self.key_value_proj_dim = config.d_kv self.n_heads = config.num_heads - self.inner_dim = self.n_heads * self.d_kv + self.inner_dim = self.n_heads * self.key_value_proj_dim # Mesh TensorFlow initialization to avoid scaling before softmax self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q") @@ -175,46 +174,48 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets Returns: a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) """ - ret = 0 - n = -relative_position + relative_buckets = 0 + # n = -relative_position if bidirectional: num_buckets //= 2 - ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets - n = tf.math.abs(n) + relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets + relative_position = tf.math.abs(relative_position) else: - n = tf.math.maximum(n, 0) + relative_position = -tf.math.minimum(relative_position, 0) # now n is in the range [0, inf) max_exact = num_buckets // 2 - is_small = tf.math.less(n, max_exact) - val_if_large = max_exact + tf.dtypes.cast( - tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact) + is_small = tf.math.less(relative_position, max_exact) + relative_position_if_large = max_exact + tf.dtypes.cast( + tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact), tf.int32, ) - val_if_large = tf.math.minimum(val_if_large, num_buckets - 1) - ret += tf.where(is_small, n, val_if_large) - return ret + relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) + return relative_buckets - def compute_bias(self, qlen, klen): + def compute_bias(self, query_length, key_length): """ Compute binned relative position bias """ - context_position = tf.range(qlen)[:, None] - memory_position = tf.range(klen)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket( + context_position = tf.range(query_length)[:, None] + memory_position = tf.range(key_length)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( relative_position, - bidirectional=self.is_bidirectional, + bidirectional=(not self.is_decoder), num_buckets=self.relative_attention_num_buckets, ) - values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) - values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = tf.expand_dims( + tf.transpose(values, [2, 0, 1]), axis=0 + ) # shape (1, num_heads, 
query_length, key_length) return values def call( self, - input, + hidden_states, mask=None, - kv=None, + key_value_states=None, position_bias=None, past_key_value=None, head_mask=None, @@ -224,95 +225,108 @@ def call( output_attentions=False, ): """ - Self-attention (if kv is None) or attention over source sentence (provided by kv). + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ - # Input is (bs, qlen, dim) - # Mask is (bs, klen) (non-causal) or (bs, klen, klen) - # past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head) - bs, qlen, dim = shape_list(input) + # Input is (batch_size, query_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = shape_list(hidden_states)[:2] + + real_seq_length = seq_length if past_key_value is not None: - assert self.is_decoder is True, "Encoder cannot cache past key value states" assert ( len(past_key_value) == 2 ), "past_key_value should have 2 past states: keys and values. Got {} past states".format( len(past_key_value) ) - real_qlen = qlen + shape_list(past_key_value[0])[2] if query_length is None else query_length - else: - real_qlen = qlen + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - if kv is None: - klen = real_qlen - else: - klen = shape_list(kv)[1] + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - def shape(x): + def shape(hidden_states): """ projection """ - return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3)) + return tf.transpose( + tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3) + ) - def unshape(x): + def unshape(hidden_states): """ compute context """ - return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim)) - - q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head) - - if kv is None: - k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head) - elif past_key_value is None: - k = v = kv - k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head) + return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim)) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """ projects hidden states correctly to key/query states """ + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) - if past_key_value is not None: - if kv is None: - k_, v_ = past_key_value - k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head) - v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) - else: - k, v = past_key_value + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = tf.concat([past_key_value, hidden_states], axis=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, 
dim_per_head) + + # get key/value + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) # to cope with keras serialization if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True: - present_key_value_state = ((k, v),) + present_key_value_state = (key_states, value_states) else: - present_key_value_state = (None,) + present_key_value_state = None - scores = tf.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen) + scores = tf.einsum( + "bnqd,bnkd->bnqk", query_states, key_states + ) # (batch_size, n_heads, query_length, key_length) if position_bias is None: if not self.has_relative_attention_bias: - raise ValueError("No position_bias provided and no weights to compute position_bias") - position_bias = self.compute_bias(real_qlen, klen) + position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32) + else: + position_bias = self.compute_bias(real_seq_length, key_length) # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -qlen:, :] + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (bs, n_heads, qlen, klen) + position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) scores += position_bias - weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) - weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) + weights = tf.nn.softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length) + weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) # Mask heads if we want to if head_mask is not None: weights = weights * head_mask - context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) + attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) - context = self.o(context) + attn_output = self.o(unshape(attn_output)) - outputs = (context,) + present_key_value_state + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) if output_attentions: outputs = outputs + (weights,) - if self.has_relative_attention_bias: - outputs = outputs + (position_bias,) + return outputs @@ -322,7 +336,6 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs): self.SelfAttention = TFT5Attention( config, has_relative_attention_bias=has_relative_attention_bias, - is_bidirectional=not config.is_decoder, name="SelfAttention", ) self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") @@ -339,9 +352,9 @@ def call( output_attentions=False, training=False, ): - norm_x = self.layer_norm(hidden_states) + normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( - norm_x, + normed_hidden_states, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, @@ -350,19 +363,17 @@ def call( output_attentions=output_attentions, training=training, ) - y = attention_output[0] - layer_output = hidden_states + self.dropout(y, training=training) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + hidden_states = hidden_states + 
self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs class TFT5LayerCrossAttention(tf.keras.layers.Layer): - def __init__(self, config, has_relative_attention_bias=False, **kwargs): + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.EncDecAttention = TFT5Attention( config, - has_relative_attention_bias=has_relative_attention_bias, - is_bidirectional=True, + has_relative_attention_bias=False, name="EncDecAttention", ) self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") @@ -371,7 +382,7 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs): def call( self, hidden_states, - kv, + key_value_states, attention_mask=None, position_bias=None, head_mask=None, @@ -381,11 +392,11 @@ def call( output_attentions=False, training=False, ): - norm_x = self.layer_norm(hidden_states) + normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( - norm_x, + normed_hidden_states, mask=attention_mask, - kv=kv, + key_value_states=key_value_states, position_bias=position_bias, head_mask=head_mask, past_key_value=past_key_value, @@ -394,9 +405,8 @@ def call( output_attentions=output_attentions, training=training, ) - y = attention_output[0] - layer_output = hidden_states + self.dropout(y, training=training) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs @@ -416,7 +426,6 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs): self.layer.append( TFT5LayerCrossAttention( config, - has_relative_attention_bias=has_relative_attention_bias, name="layer_._1", ) ) @@ -477,7 +486,7 @@ def call( cross_attention_outputs = self.layer[1]( hidden_states, - kv=encoder_hidden_states, + key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, head_mask=head_mask, @@ -731,17 +740,18 @@ def call( # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) hidden_states, present_key_value_state = layer_outputs[:2] - if i == 0: - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[3 if output_attentions else 2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, past_key_values, (self-attention weights), + # (self-attention position bias), (cross-attention position bias), (cross-attention weights), + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] # append next layer key value states present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) + all_attentions = 
all_attentions + (layer_outputs[3],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states, training=training) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4de6d3f62d6c..80e510fc35e4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -461,7 +461,7 @@ def test_headmasking(self): inputs = self._prepare_for_class(inputs_dict, model_class).copy() inputs["head_mask"] = head_mask - outputs = model(**inputs) + outputs = model(**inputs, return_dict=True) # Test that we can get a gradient back for importance score computation output = sum(t.sum() for t in outputs[0]) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 77b8ee1cae9b..11e3ab0ad872 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -556,6 +556,32 @@ def model(self): def tokenizer(self): return T5Tokenizer.from_pretrained("t5-base") + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import t5 # pip install t5==0.7.1 + >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_t5_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -19.0845 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + @slow def test_summarization(self): model = self.model @@ -567,8 +593,8 @@ def test_summarization(self): ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. 
She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.' expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .', - "the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .", + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .', + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .", "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .", 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .', ] diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index f42e03ca1681..af6080de77da 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -311,8 +311,8 @@ def test_summarization(self): ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. 
Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.' expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .', - "the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .", + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .', + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .", "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .", 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .', ]
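As a usage note for the new `test_small_integration_test` added to tests/test_modeling_t5.py above: the test turns the mean per-token cross-entropy returned by `T5ForConditionalGeneration` into the total sequence log-likelihood that `t5.models.MtfModel.score` reports. A minimal sketch of that conversion, assuming the `t5-small` checkpoint can be downloaded:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True)

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids

with torch.no_grad():
    loss = model(input_ids, labels=labels).loss    # mean cross-entropy per target token

mtf_score = -(labels.shape[-1] * loss.item())      # total log-likelihood of the target sequence
print(mtf_score)                                   # close to the EXPECTED_SCORE of -19.0845 above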