11 changes: 6 additions & 5 deletions keras_nlp/models/t5/t5_backbone.py
@@ -19,7 +19,6 @@
 from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm
 from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer
 from keras_nlp.utils.python_utils import classproperty
-from keras_nlp.utils.tensor_utils import assert_tf_backend


 @keras_nlp_export("keras_nlp.models.T5Backbone")
@@ -81,8 +80,6 @@ def __init__(
         tie_embedding_weights=False,
         **kwargs,
     ):
-        assert_tf_backend(self.__class__.__name__)
-
         # Encoder inputs
         encoder_token_ids = keras.Input(
             shape=(None,), dtype="int32", name="encoder_token_ids"
@@ -121,7 +118,7 @@ def __init__

         position_bias = None
         for i in range(num_layers):
-            x, position_bias = T5TransformerLayer(
+            output = T5TransformerLayer(
                 is_decoder=False,
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
@@ -138,6 +135,8 @@ def __init__(
                 position_bias=position_bias,
                 use_causal_mask=False,
             )
+            if isinstance(output, tuple):
+                x, position_bias = output

         x = T5LayerNorm(
             epsilon=layer_norm_epsilon,
@@ -162,7 +161,7 @@ def __init__(

         position_bias = None
         for i in range(num_layers):
-            x, position_bias = T5TransformerLayer(
+            output = T5TransformerLayer(
                 is_decoder=True,
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
@@ -181,6 +180,8 @@ def __init__(
                 encoder_attention_mask=encoder_attention_mask,
                 use_causal_mask=True,
             )
+            if isinstance(output, tuple):
+                x, position_bias = output

         x = T5LayerNorm(
             epsilon=layer_norm_epsilon,
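On the backbone side, the layer call no longer unpacks unconditionally: `T5TransformerLayer` can now return either a plain tensor or an `(outputs, position_bias)` tuple (see the last file in this diff), so the loop unpacks only when a bias came back. A sketch of the pattern, not the literal backbone code; `layers` and the keyword names here are stand-ins following the diff above:

```python
# Sketch only: `layers` is a hypothetical list of constructed
# T5TransformerLayer instances.
position_bias = None
for layer in layers:
    output = layer(
        x,
        attention_mask=attention_mask,
        position_bias=position_bias,
        use_causal_mask=False,
    )
    if isinstance(output, tuple):
        # The first layer computes the shared relative position bias;
        # later layers receive and return it unchanged.
        x, position_bias = output
    else:
        x = output  # defensive: no bias produced
```

The bias is computed once by the first layer and threaded through the rest of the stack, mirroring how T5 shares its relative attention bias across layers.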
1 change: 0 additions & 1 deletion keras_nlp/models/t5/t5_backbone_test.py
@@ -19,7 +19,6 @@
 from keras_nlp.tests.test_case import TestCase


-@pytest.mark.tf_only
 class T5BackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
9 changes: 3 additions & 6 deletions keras_nlp/models/t5/t5_layer_norm.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import tensorflow as tf
-
 from keras_nlp.backend import keras
+from keras_nlp.backend import ops


 class T5LayerNorm(keras.layers.Layer):
@@ -31,8 +30,6 @@ def build(self, input_shape):
         self.built = True

     def call(self, hidden_states):
-        variance = tf.math.reduce_mean(
-            tf.math.square(hidden_states), axis=-1, keepdims=True
-        )
-        hidden_states = hidden_states * tf.math.rsqrt(variance + self.epsilon)
+        variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
+        hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon)
         return self.weight * hidden_states
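For reference, this `call` is T5's RMS-style layer norm: it scales by the reciprocal root mean square of the features, with no mean subtraction and no bias term. A minimal NumPy sketch of the same math (the `epsilon` default here is assumed for illustration):

```python
import numpy as np

def t5_layer_norm(x, weight, epsilon=1e-6):
    # Scale by 1/sqrt(mean(x^2) + eps); no mean subtraction, no bias.
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return weight * (x / np.sqrt(variance + epsilon))

x = np.random.randn(2, 3, 8).astype("float32")
weight = np.ones((8,), dtype="float32")  # learned scale in the real layer
y = t5_layer_norm(x, weight)  # same shape as x
```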
86 changes: 39 additions & 47 deletions keras_nlp/models/t5/t5_multi_head_attention.py
@@ -12,18 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import tensorflow as tf
-from tensorflow.compiler.tf2xla.python.xla import dynamic_slice
+import numpy as np

 from keras_nlp.backend import keras
-
-
-def shape_list(tensor):
-    dynamic = tf.shape(tensor)
-    if tensor.shape == tf.TensorShape(None):
-        return dynamic
-    static = tensor.shape.as_list()
-    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+from keras_nlp.backend import ops


 class T5MultiHeadAttention(keras.layers.Layer):
@@ -123,39 +115,39 @@ def _relative_position_bucket(
         if bidirectional:
             num_buckets //= 2
             relative_buckets += (
-                tf.cast(
-                    tf.math.greater(relative_position, 0),
+                ops.cast(
+                    ops.greater(relative_position, 0),
                     dtype=relative_position.dtype,
                 )
                 * num_buckets
             )
-            relative_position = tf.math.abs(relative_position)
+            relative_position = ops.abs(relative_position)
         else:
-            relative_position = -tf.math.minimum(relative_position, 0)
+            relative_position = -ops.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
-        is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.cast(
-            tf.math.log(
-                tf.cast(relative_position, "float32")
-                / tf.cast(max_exact, "float32")
+        is_small = ops.less(relative_position, max_exact)
+        relative_position_if_large = max_exact + ops.cast(
+            ops.log(
+                ops.cast(relative_position, "float32")
+                / ops.cast(max_exact, "float32")
             )
-            / tf.math.log(max_distance / max_exact)
+            / ops.cast(ops.log(max_distance / max_exact), "float32")
             * (num_buckets - max_exact),
             dtype=relative_position.dtype,
         )
-        relative_position_if_large = tf.math.minimum(
+        relative_position_if_large = ops.minimum(
             relative_position_if_large, num_buckets - 1
         )
-        relative_buckets += tf.where(
+        relative_buckets += ops.where(
             is_small, relative_position, relative_position_if_large
         )
         return relative_buckets

     def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
-        context_position = tf.range(query_length)[:, None]
-        memory_position = tf.range(key_length)[None, :]
+        context_position = ops.arange(query_length)[:, None]
+        memory_position = ops.arange(key_length)[None, :]
         relative_position = (
             memory_position - context_position
         )  # shape (query_length, key_length)
@@ -165,11 +157,11 @@ def compute_bias(self, query_length, key_length):
             num_buckets=self.relative_attention_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = tf.gather(
-            self.relative_attention_bias, relative_position_bucket
+        values = ops.take(
+            self.relative_attention_bias, relative_position_bucket, axis=0
         )  # shape (query_length, key_length, num_heads)
-        values = tf.expand_dims(
-            tf.transpose(values, [2, 0, 1]), axis=0
+        values = ops.expand_dims(
+            ops.transpose(values, axes=(2, 0, 1)), axis=0
         )  # shape (1, num_heads, query_length, key_length)
         return values

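The two hunks above port T5's relative attention bias. `_relative_position_bucket` maps each (query, key) offset to one of `num_buckets` ids: in the bidirectional case half the buckets cover positive offsets, small offsets each get an exact bucket, and larger ones fall into logarithmically wider buckets capped at `max_distance`. A standalone NumPy sketch of the bidirectional case, assuming T5's usual defaults of 32 buckets and max distance 128:

```python
import numpy as np

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    # Bidirectional case: half the buckets for keys after the query.
    num_buckets //= 2
    buckets = (relative_position > 0).astype("int32") * num_buckets
    relative_position = np.abs(relative_position)
    # Offsets below max_exact each get their own bucket.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # Larger offsets share logarithmically sized buckets. The clip only
    # avoids log(0); those positions take the `is_small` branch anyway.
    if_large = max_exact + (
        np.log(relative_position.clip(min=1) / max_exact)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype("int32")
    if_large = np.minimum(if_large, num_buckets - 1)
    return buckets + np.where(is_small, relative_position, if_large)

# Offsets -1, 0, and +1 land in buckets 1, 0, and 17 respectively.
print(relative_position_bucket(np.arange(-4, 5)))
```

`compute_bias` then gathers the learned per-bucket embeddings and transposes them into the `(1, num_heads, query_length, key_length)` tensor added to the attention scores.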
@@ -186,7 +178,7 @@ def call(
     ):
         # Input is (batch_size, query_length, dim)
         # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head)
-        batch_size, seq_length = shape_list(hidden_states)[:2]
+        batch_size, seq_length = ops.shape(hidden_states)[:2]

         real_seq_length = seq_length

@@ -197,29 +189,29 @@
                     f"keys and values. Got {len(past_key_value)} past states."
                 )
             real_seq_length += (
-                shape_list(past_key_value[0])[2]
+                ops.shape(past_key_value[0])[2]
                 if query_length is None
                 else query_length
             )

         key_length = (
             real_seq_length
             if key_value_states is None
-            else shape_list(key_value_states)[1]
+            else ops.shape(key_value_states)[1]
         )

         def shape(hidden_states):
-            return tf.transpose(
-                tf.reshape(
+            return ops.transpose(
+                ops.reshape(
                     hidden_states,
                     (batch_size, -1, self.num_heads, self.key_value_dim),
                 ),
-                perm=(0, 2, 1, 3),
+                axes=(0, 2, 1, 3),
            )

         def unshape(hidden_states):
-            return tf.reshape(
-                tf.transpose(hidden_states, perm=(0, 2, 1, 3)),
+            return ops.reshape(
+                ops.transpose(hidden_states, axes=(0, 2, 1, 3)),
                 (batch_size, -1, self.inner_dim),
             )

@@ -240,7 +232,7 @@ def project(
                 if key_value_states is None:
                     # self-attention
                     # (batch_size, num_heads, key_length, dim_per_head)
-                    hidden_states = tf.concat(
+                    hidden_states = ops.concat(
                         [past_key_value, hidden_states], axis=2
                     )
                 else:
@@ -267,13 +259,13 @@ def project(
             past_key_value[1] if past_key_value is not None else None,
         )

-        scores = tf.einsum(
+        scores = ops.einsum(
             "bnqd,bnkd->bnqk", query_states, key_states
         )  # (batch_size, num_heads, query_length, key_length)

         if position_bias is None:
             if not self.use_relative_attention_bias:
-                position_bias = tf.zeros(
+                position_bias = ops.zeros(
                     (1, self.num_heads, real_seq_length, key_length),
                     self.compute_dtype,
                 )
@@ -289,24 +281,24 @@
                 # we might have a padded past structure,
                 # in which case we want to fetch the position bias slice
                 # right after the most recently filled past index
-                most_recently_filled_past_index = tf.reduce_max(
-                    tf.where(past_key_value[0][0, 0, :, 0] != 0.0)
+                most_recently_filled_past_index = ops.amax(
+                    ops.where(past_key_value[0][0, 0, :, 0] != 0.0)
                 )
-                position_bias = dynamic_slice(
+                position_bias = ops.slice(
                     position_bias,
                     (0, 0, most_recently_filled_past_index + 1, 0),
                     (1, self.num_heads, seq_length, real_seq_length),
                 )

         if mask is not None:
             # Add a new mask axis for the head dim.
-            mask = mask[:, tf.newaxis, :, :]
+            mask = mask[:, np.newaxis, :, :]
             # Add a very large negative position bias for masked positions.
-            mask = (1.0 - tf.cast(mask, position_bias.dtype)) * -1e9
+            mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9
             position_bias = position_bias + mask

         scores += position_bias
-        weights = tf.nn.softmax(
+        weights = ops.nn.softmax(
             scores, axis=-1
         )  # (batch_size, num_heads, query_length, key_length)
         weights = self.dropout_layer(
@@ -315,9 +307,9 @@

         # Optionally mask heads
         if layer_head_mask is not None:
-            weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
+            weights = ops.reshape(layer_head_mask, (1, -1, 1, 1)) * weights

-        attention_output = tf.matmul(
+        attention_output = ops.matmul(
             weights, value_states
         )  # (batch_size, num_heads, query_length, dim_per_head)

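One detail worth calling out in the hunks above: padding is applied additively. Masked-out key positions get `-1e9` added to their scores, so the subsequent softmax drives their attention weights to zero. A small NumPy sketch of that mechanic:

```python
import numpy as np

scores = np.random.randn(1, 2, 3, 3)             # (batch, heads, query, key)
mask = np.array([[[1, 1, 0]]], dtype="float32")  # last key position is padding
mask = mask[:, np.newaxis, :, :]                 # add a head axis, as above
position_bias = (1.0 - mask) * -1e9              # ~-inf where mask == 0

weights = np.exp(scores + position_bias)
weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the key axis
# weights[..., -1] is now ~0: padded keys receive no attention.
```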
12 changes: 7 additions & 5 deletions keras_nlp/models/t5/t5_transformer_layer.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import tensorflow as tf
-
 from keras_nlp.backend import keras
+from keras_nlp.backend import ops
 from keras_nlp.layers.modeling.transformer_layer_utils import (
     compute_causal_mask,
 )
@@ -103,10 +102,10 @@ def call(
         training=False,
     ):
         if use_causal_mask:
-            shape = tf.shape(hidden_states)
+            shape = ops.shape(hidden_states)
             batch_size, length = shape[0], shape[1]
             causal_mask = compute_causal_mask(batch_size, length, length)
-            attention_mask = tf.cast(attention_mask, "int32")
+            attention_mask = ops.cast(attention_mask, "int32")
             attention_mask = causal_mask & attention_mask

         x = hidden_states  # Intermediate result.
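For context, `compute_causal_mask` is keras_nlp's shared utility that returns a lower-triangular `(batch, query, key)` mask; ANDing it with the incoming padding mask means each position can attend only to earlier, non-padded tokens. A NumPy sketch of the combination, with a `tril` construction standing in for the utility:

```python
import numpy as np

batch_size, length = 1, 4
# Stand-in for keras_nlp's compute_causal_mask(batch_size, length, length).
causal_mask = np.broadcast_to(
    np.tril(np.ones((length, length), dtype="int32")),
    (batch_size, length, length),
)
padding_mask = np.array([[1, 1, 1, 0]], dtype="int32")  # last token is padding

# Broadcast the padding mask over the query axis and combine, as in `call`.
attention_mask = causal_mask & padding_mask[:, np.newaxis, :]
# attention_mask[0, i, j] == 1 only if j <= i and token j is not padding.
```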
@@ -147,4 +146,7 @@
         x = self.dropout_layer(x, training=training)
         x = x + residual

-        return x, position_bias
+        if position_bias is not None:
+            return x, position_bias
+        else:
+            return x