Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions keras_nlp/models/falcon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
159 changes: 159 additions & 0 deletions keras_nlp/models/falcon/falcon_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from keras_nlp.backend import keras
from keras_nlp.backend import ops


class FalconAttention(keras.layers.Layer):
def __init__(
self,
num_heads,
attention_dropout,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention_dropout = attention_dropout

def build(self, inputs_shape):
batch_size, seq_length, hidden_dim = inputs_shape

self.head_dim = hidden_dim // self.num_heads

# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

self._query_dense = keras.layers.EinsumDense(
equation="btm,mnh->btnh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
dtype=self.dtype_policy,
name="query_dense",
)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
equation="bsm,mnh->bsnh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
dtype=self.dtype_policy,
name="key_dense",
)
self._key_dense.build(inputs_shape)

self._value_dense = keras.layers.EinsumDense(
equation="bsm,mnh->bsnh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
dtype=self.dtype_policy,
name="value_dense",
)
self._value_dense.build(inputs_shape)

self._attention_dropout = keras.layers.Dropout(
rate=self.attention_dropout,
dtype=self.dtype_policy,
name="attention_dropout",
)

self._output_dense = keras.layers.Dense(
hidden_dim,
dtype=self.dtype_policy,
name="output_dense",
)
self._output_dense.build(inputs_shape)

self._softmax = keras.layers.Softmax(dtype="float32", name="softmax")

self.built = True

def call(
self,
inputs,
alibi,
attention_mask=None,
cache=None,
cache_update_index=None,
):
batch_size, seq_length, hidden_dim = ops.shape(inputs)

query = self._query_dense(inputs)
key = self._key_dense(inputs)
value = self._value_dense(inputs)

if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key)
value = ops.slice_update(value_cache, start, value)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)

# query (batch_size, num_heads, query_length, head_dim)
query = ops.transpose(query, [0, 2, 1, 3])
# value (batch_size, num_heads, kv_length, head_dim)
value = ops.transpose(value, [0, 2, 1, 3])
# key (batch_size, num_heads, head_dim, kv_length)
key = ops.transpose(key, [0, 2, 3, 1])

attention_scores = ops.matmul(query, key)
attention_scores = ops.add(attention_scores, alibi)
attention_scores = (
attention_scores * self.inv_norm_factor
) # [batch_size, num_heads, query_length, kv_length]
attention_scores = self._softmax(
attention_scores, ops.expand_dims(attention_mask, 1)
)
attention_scores = self._attention_dropout(attention_scores)
attention_output = ops.matmul(
attention_scores, value
) # [batch_size, num_heads, query_length, head_dim]

attention_output = ops.transpose(
attention_output, [0, 2, 1, 3]
) # [batch_size, query_length, num_heads, head_dim]
attention_output = ops.reshape(
attention_output,
[batch_size, seq_length, self.num_heads * self.head_dim],
) # [batch_size, query_length, hidden_dim]

attention_output = self._output_dense(attention_output)

if cache is not None:
return attention_output, cache

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_heads": self.num_heads,
"attention_dropout": self.attention_dropout,
}
)
return config
164 changes: 164 additions & 0 deletions keras_nlp/models/falcon/falcon_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.falcon.falcon_transformer_decoder import (
FalconTransformerDecoder,
)


@keras_nlp_export("keras_nlp.models.FalconBackbone")
class FalconBackbone(Backbone):
"""The Falcon core architecure.

This network implements a Transformer-based decoder-only network,
[Falcon](https://arxiv.org/abs/2306.01116).

Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
num_attention_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
hidden_dim: int. The dimensionality of the embeddings and hidden states.
intermediate_dim: int. The output dimension of the first Dense layer in
the MLP network of each transformer.
layer_norm_epsilon: float. Epsilon for the layer normalization layers in
the transformer decoder.
attention_dropout: float. Dropout probability for the attention.
feedforward_dropout: flaot. Dropout probability for the feedforward.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.

Examples:
```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}

# Pretrained Falcon decoder.
# TODO: Update the preset.
model = keras_nlp.models.FalconBackbone.from_preset("falcon_preset")
model(input_data)

# Randomly initialized Falcon decoder with a custom config.
model = keras_nlp.models.FalconBackbone(
vocabulary_size=10,
num_layers=2,
num_attention_heads=2,
hidden_dim=32,
intermediate_dim=32*4,
layer_norm_epsilon=1e-5,
attention_dropout=0,
feedforward_dropout=0,
dtype="float32",
)
model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
num_layers,
num_attention_heads,
hidden_dim,
intermediate_dim,
layer_norm_epsilon=1e-5,
attention_dropout=0,
feedforward_dropout=0,
dtype=None,
**kwargs,
):
# === Layers ===
# Embed Tokens
token_embedding_layer = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
dtype=dtype,
name="token_embedding",
)

# Apply successive transformer decoder blocks.
transformer_layers = []
for i in range(num_layers):
layer = FalconTransformerDecoder(
num_attention_heads=num_attention_heads,
intermediate_dim=intermediate_dim,
attention_dropout=attention_dropout,
feedforward_dropout=feedforward_dropout,
dtype=dtype,
name=f"transformer_layer_{i}",
)
transformer_layers.append(layer)

final_layernorm = keras.layers.LayerNormalization(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="final_layernorm",
)

# === Functional Model ===
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)
x = token_embedding_layer(token_ids)

for transformer_layer in transformer_layers:
x = transformer_layer(inputs=x, decoder_padding_mask=padding_mask)
sequence_output = final_layernorm(x)

super().__init__(
inputs={
"token_ids": token_ids,
"padding_mask": padding_mask,
},
outputs=sequence_output,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.attention_dropout = attention_dropout
self.feedforward_dropout = feedforward_dropout
self.layer_norm_epsilon = layer_norm_epsilon

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"attention_dropout": self.attention_dropout,
"feedforward_dropout": self.feedforward_dropout,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config

@property
def token_embedding(self):
return self.get_layer("token_embedding")
49 changes: 49 additions & 0 deletions keras_nlp/models/falcon/falcon_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from keras_nlp.backend import ops
from keras_nlp.models.falcon.falcon_backbone import FalconBackbone
from keras_nlp.tests.test_case import TestCase


class FalconBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 10,
"num_layers": 2,
"num_attention_heads": 8,
"hidden_dim": 16,
"intermediate_dim": 32,
}
self.input_data = {
"token_ids": ops.ones((2, 5), dtype="int32"),
"padding_mask": ops.ones((2, 5), dtype="int32"),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=FalconBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 16),
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=FalconBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
Loading