diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py
index 5b7a333f47c2..eaed8ba6c9a0 100644
--- a/src/transformers/modeling_tf_bert.py
+++ b/src/transformers/modeling_tf_bert.py
@@ -212,40 +212,48 @@ def __init__(self, config, **kwargs):
             )
 
         self.num_attention_heads = config.num_attention_heads
+        assert config.hidden_size % config.num_attention_heads == 0
+
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.query = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        self.query = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query",
         )
-        self.key = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+        self.key = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key",
         )
-        self.value = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        self.value = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value",
         )
         self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x, batch_size):
-        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
+        query_tensor = self.query(hidden_states)
 
-        return tf.transpose(x, perm=[0, 2, 1, 3])
+        # `key_tensor` = [B, S, N, H]
+        key_tensor = self.key(hidden_states)
 
-    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
-        batch_size = shape_list(hidden_states)[0]
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
-        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = tf.matmul(
-            query_layer, key_layer, transpose_b=True
-        )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
-        attention_scores = attention_scores / tf.math.sqrt(dk)
+        # `value_tensor` = [B, S, N, H]
+        value_tensor = self.value(hidden_states)
+
+        # Take the dot product between "query" and "key" to get the raw
+        # attention scores.
+        attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
+        dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype)  # scale attention_scores
+        attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
@@ -262,12 +270,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai
 
         if head_mask is not None:
            attention_probs = attention_probs * head_mask
 
-        context_layer = tf.matmul(attention_probs, value_layer)
-        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
-        context_layer = tf.reshape(
-            context_layer, (batch_size, -1, self.all_head_size)
-        )  # (batch_size, seq_len_q, all_head_size)
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
+        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
 
         return outputs
 
@@ -276,8 +280,18 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.num_attention_heads = config.num_attention_heads
+
+        assert config.hidden_size % config.num_attention_heads == 0
+
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abcd,cde->abe",
+            output_shape=(None, self.all_head_size),
+            bias_axes="e",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -314,8 +328,12 @@ class TFBertIntermediate(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            output_shape=(None, config.intermediate_size),
+            bias_axes="d",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
 
         if isinstance(config.hidden_act, str):
@@ -334,8 +352,12 @@ class TFBertOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            bias_axes="d",
+            output_shape=(None, config.hidden_size),
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py
index c680042a8aec..58cb5219dab2 100644
--- a/src/transformers/modeling_tf_electra.py
+++ b/src/transformers/modeling_tf_electra.py
@@ -66,40 +66,48 @@ def __init__(self, config, **kwargs):
             )
 
         self.num_attention_heads = config.num_attention_heads
+        assert config.hidden_size % config.num_attention_heads == 0
+
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.query = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        self.query = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query",
         )
-        self.key = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+        self.key = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key",
         )
-        self.value = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        self.value = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value",
         )
         self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x, batch_size):
-        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
+        query_tensor = self.query(hidden_states)
 
-        return tf.transpose(x, perm=[0, 2, 1, 3])
+        # `key_tensor` = [B, S, N, H]
+        key_tensor = self.key(hidden_states)
 
-    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
-        batch_size = shape_list(hidden_states)[0]
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
-        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = tf.matmul(
-            query_layer, key_layer, transpose_b=True
-        )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
-        attention_scores = attention_scores / tf.math.sqrt(dk)
+        # `value_tensor` = [B, S, N, H]
+        value_tensor = self.value(hidden_states)
+
+        # Take the dot product between "query" and "key" to get the raw
+        # attention scores.
+        attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
+        dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype)  # scale attention_scores
+        attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
@@ -116,12 +124,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai
 
         if head_mask is not None:
            attention_probs = attention_probs * head_mask
 
-        context_layer = tf.matmul(attention_probs, value_layer)
-        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
-        context_layer = tf.reshape(
-            context_layer, (batch_size, -1, self.all_head_size)
-        )  # (batch_size, seq_len_q, all_head_size)
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
+        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
 
         return outputs
 
@@ -131,8 +135,18 @@ class TFElectraSelfOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.num_attention_heads = config.num_attention_heads
+
+        assert config.hidden_size % config.num_attention_heads == 0
+
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abcd,cde->abe",
+            output_shape=(None, self.all_head_size),
+            bias_axes="e",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -171,8 +185,12 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            output_shape=(None, config.intermediate_size),
+            bias_axes="d",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
 
         if isinstance(config.hidden_act, str):
@@ -192,8 +210,12 @@ class TFElectraOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            bias_axes="d",
+            output_shape=(None, config.hidden_size),
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py
index 1918a21022c7..24b4d685641e 100644
--- a/src/transformers/modeling_tf_longformer.py
+++ b/src/transformers/modeling_tf_longformer.py
@@ -270,8 +270,12 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            output_shape=(None, config.intermediate_size),
+            bias_axes="d",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
 
         if isinstance(config.hidden_act, str):
@@ -291,8 +295,12 @@ class TFLongformerOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            bias_axes="d",
+            output_shape=(None, config.hidden_size),
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -326,7 +334,6 @@ def call(self, hidden_states):
         return pooled_output
 
 
-# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
 class TFLongformerSelfOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py
index 54334bdccbd4..b7e965f9a278 100644
--- a/src/transformers/modeling_tf_roberta.py
+++ b/src/transformers/modeling_tf_roberta.py
@@ -243,40 +243,48 @@ def __init__(self, config, **kwargs):
             )
 
         self.num_attention_heads = config.num_attention_heads
+        assert config.hidden_size % config.num_attention_heads == 0
+
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.query = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        self.query = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query",
         )
-        self.key = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+        self.key = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key",
         )
-        self.value = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        self.value = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cde->abde",
+            output_shape=(None, config.num_attention_heads, self.attention_head_size),
+            bias_axes="de",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value",
         )
         self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x, batch_size):
-        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
+        query_tensor = self.query(hidden_states)
 
-        return tf.transpose(x, perm=[0, 2, 1, 3])
+        # `key_tensor` = [B, S, N, H]
+        key_tensor = self.key(hidden_states)
 
-    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
-        batch_size = shape_list(hidden_states)[0]
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
-        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = tf.matmul(
-            query_layer, key_layer, transpose_b=True
-        )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
-        attention_scores = attention_scores / tf.math.sqrt(dk)
+        # `value_tensor` = [B, S, N, H]
+        value_tensor = self.value(hidden_states)
+
+        # Take the dot product between "query" and "key" to get the raw
+        # attention scores.
+        attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
+        dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype)  # scale attention_scores
+        attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
@@ -293,12 +301,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai
 
         if head_mask is not None:
            attention_probs = attention_probs * head_mask
 
-        context_layer = tf.matmul(attention_probs, value_layer)
-        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
-        context_layer = tf.reshape(
-            context_layer, (batch_size, -1, self.all_head_size)
-        )  # (batch_size, seq_len_q, all_head_size)
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
+        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
 
         return outputs
 
@@ -308,8 +312,18 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.num_attention_heads = config.num_attention_heads
+
+        assert config.hidden_size % config.num_attention_heads == 0
+
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abcd,cde->abe",
+            output_shape=(None, self.all_head_size),
+            bias_axes="e",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -348,8 +362,12 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            output_shape=(None, config.intermediate_size),
+            bias_axes="d",
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
        )
 
         if isinstance(config.hidden_act, str):
@@ -369,8 +387,12 @@ class TFRobertaOutput(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            bias_axes="d",
+            output_shape=(None, config.hidden_size),
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
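
Reviewer note (not part of the patch): below is a minimal, self-contained sketch, with made-up toy sizes and variable names, of why the `EinsumDense` equation `"abc,cde->abde"` can stand in for the old `Dense` + `transpose_for_scores` pair, and why the `"BSNH,BTNH->BNTS"` / `"BNTS,BSNH->BTNH"` einsums reproduce the previous `tf.matmul` attention (with `tf.math.rsqrt(dk)` simply being `1 / sqrt(dk)`). It assumes `tf.keras.layers.experimental.EinsumDense` as used in the diff (TF 2.3+); both prints should be ~0 up to float32 rounding.

```python
import tensorflow as tf

# Toy sizes (made up for illustration): batch=2, seq=5, hidden=8, 2 heads of size 4.
B, S, HIDDEN, N, H = 2, 5, 8, 2, 4
hidden_states = tf.random.normal((B, S, HIDDEN))

# --- Old projection: Dense to all_head_size, then reshape/transpose per head.
dense = tf.keras.layers.Dense(N * H)
mixed = dense(hidden_states)                                          # [B, S, N*H]
old_q = tf.transpose(tf.reshape(mixed, (B, S, N, H)), (0, 2, 1, 3))   # [B, N, S, H]

# --- New projection: one EinsumDense fuses the matmul and the head split.
ein = tf.keras.layers.experimental.EinsumDense(
    equation="abc,cde->abde", output_shape=(None, N, H), bias_axes="de"
)
ein.build(hidden_states.shape)
# Reuse the Dense weights so the two paths are numerically comparable;
# the EinsumDense kernel is [hidden, heads, head_size].
ein.set_weights([
    tf.reshape(dense.kernel, (HIDDEN, N, H)).numpy(),
    tf.reshape(dense.bias, (N, H)).numpy(),
])
new_q = ein(hidden_states)                                            # [B, S, N, H]
print(tf.reduce_max(tf.abs(old_q - tf.transpose(new_q, (0, 2, 1, 3)))).numpy())

# --- Attention core: the two einsum equations reproduce the matmul version.
q, k, v = (tf.random.normal((B, S, N, H)) for _ in range(3))          # [B, S, N, H]
qt, kt, vt = (tf.transpose(t, (0, 2, 1, 3)) for t in (q, k, v))       # [B, N, S, H]

# Old path: matmul on [B, N, S, H] tensors, scaled by 1/sqrt(head_size).
old_scores = tf.matmul(qt, kt, transpose_b=True) / tf.math.sqrt(float(H))
old_ctx = tf.matmul(tf.nn.softmax(old_scores, axis=-1), vt)           # [B, N, S, H]

# New path: einsum on [B, S, N, H] tensors. "BSNH,BTNH->BNTS" puts query
# positions on axis T and key positions on axis S, so the softmax over the
# last axis still normalizes over keys.
new_scores = tf.einsum("BSNH,BTNH->BNTS", k, q) * tf.math.rsqrt(float(H))
new_ctx = tf.einsum("BNTS,BSNH->BTNH", tf.nn.softmax(new_scores, axis=-1), v)
print(tf.reduce_max(tf.abs(old_ctx - tf.transpose(new_ctx, (0, 2, 1, 3)))).numpy())
```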