94 changes: 58 additions & 36 deletions src/transformers/modeling_tf_bert.py
@@ -212,40 +212,48 @@ def __init__(self, config, **kwargs):
)

self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
self.query = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="query",
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
self.key = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="key",
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
self.value = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="value",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
query_tensor = self.query(hidden_states)

return tf.transpose(x, perm=[0, 2, 1, 3])
# `key_tensor` = [B, S, N, H]
key_tensor = self.key(hidden_states)

def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
# `value_tensor` = [B, S, N, H]
value_tensor = self.value(hidden_states)

# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype) # scale attention_scores
attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in TFBertModel call() function)
@@ -262,12 +270,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

return outputs
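To see what the new projection and einsum path compute, here is a small standalone sketch (not part of the patch) comparing the `EinsumDense`-based attention scores with the previous reshape/transpose/`matmul` formulation. The toy shapes, the omission of initializers and dropout, and the TF >= 2.3 `tf.keras.layers.experimental.EinsumDense` location are assumptions for illustration only.

```python
import tensorflow as tf

# Toy shapes for illustration only; real values come from the model config.
B, S, N, H = 2, 5, 2, 4            # batch, seq_len, num_heads, head_size
hidden_states = tf.random.normal((B, S, N * H))

# New path: project straight to [B, S, N, H]; the per-head split lives in the equation.
query = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde", output_shape=(None, N, H), bias_axes="de", name="query"
)
key = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde", output_shape=(None, N, H), bias_axes="de", name="key"
)
q = query(hidden_states)           # [B, T(query), N, H]
k = key(hidden_states)             # [B, S(key), N, H]

# Scaled raw attention scores, shape [B, N, T, S].
scores = tf.einsum("BSNH,BTNH->BNTS", k, q) * tf.math.rsqrt(tf.cast(H, tf.float32))

# Old path on the same projections: transpose to [B, N, seq, H] and use matmul.
q_t = tf.transpose(q, perm=[0, 2, 1, 3])
k_t = tf.transpose(k, perm=[0, 2, 1, 3])
scores_ref = tf.matmul(q_t, k_t, transpose_b=True) / tf.math.sqrt(tf.cast(H, tf.float32))

tf.debugging.assert_near(scores, scores_ref, atol=1e-5)   # same attention scores
```

The design point is that the head split happens inside the `"abc,cde->abde"` equation, so the explicit `transpose_for_scores` reshape is no longer needed.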

@@ -276,8 +280,18 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -314,8 +328,12 @@ class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)

if isinstance(config.hidden_act, str):
@@ -334,8 +352,12 @@ class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
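For the intermediate and output layers, the equation `"abc,cd->abd"` contracts only the last axis, so the `EinsumDense` behaves exactly like the `Dense` it replaces. A short equivalence sketch, assuming toy shapes and copying the weights across (again relying on the experimental `EinsumDense` location in TF >= 2.3):

```python
import tensorflow as tf

# Toy shapes for illustration only.
B, S, C, D = 2, 5, 8, 16
x = tf.random.normal((B, S, C))

# "abc,cd->abd" contracts only the last axis, i.e. a per-token Dense.
einsum_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, D), bias_axes="d", name="dense"
)
y_einsum = einsum_dense(x)

# A plain Dense given the same weights produces the same output.
dense = tf.keras.layers.Dense(D)
dense.build(x.shape)
dense.set_weights([einsum_dense.kernel.numpy(), einsum_dense.bias.numpy()])
y_dense = dense(x)

tf.debugging.assert_near(y_einsum, y_dense, atol=1e-5)
```

The same pattern repeats verbatim in the Electra and Longformer files below, so the sketches above cover those hunks as well.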
94 changes: 58 additions & 36 deletions src/transformers/modeling_tf_electra.py
@@ -66,40 +66,48 @@ def __init__(self, config, **kwargs):
)

self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
self.query = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="query",
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
self.key = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="key",
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
self.value = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="value",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
query_tensor = self.query(hidden_states)

return tf.transpose(x, perm=[0, 2, 1, 3])
# `key_tensor` = [B, S, N, H]
key_tensor = self.key(hidden_states)

def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
# `value_tensor` = [B, S, N, H]
value_tensor = self.value(hidden_states)

# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
dk = tf.cast(self.attention_head_size, dtype=attention_scores.dtype) # scale attention_scores
attention_scores = tf.multiply(attention_scores, tf.math.rsqrt(dk))

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in TFBertModel call() function)
@@ -116,12 +124,8 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs, value_tensor)
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

return outputs

@@ -131,8 +135,18 @@ class TFElectraSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.num_attention_heads = config.num_attention_heads

assert config.hidden_size % config.num_attention_heads == 0

self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -171,8 +185,12 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)

if isinstance(config.hidden_act, str):
@@ -192,8 +210,12 @@ class TFElectraOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
17 changes: 12 additions & 5 deletions src/transformers/modeling_tf_longformer.py
@@ -270,8 +270,12 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)

if isinstance(config.hidden_act, str):
@@ -291,8 +295,12 @@ class TFLongformerOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -326,7 +334,6 @@ def call(self, hidden_states):
return pooled_output


# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFLongformerSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)