-
Notifications
You must be signed in to change notification settings - Fork 305
Add Bloom Model #1382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add Bloom Model #1382
Changes from all commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
438bf01
Add Bloom Model
abuelnasr0 bb281d9
Add Backbone test and some fixes
abuelnasr0 762f1a0
Add BloomBackbone to keras_nlp.models
abuelnasr0 a01fcd8
Fix a typo in layer naming
abuelnasr0 889f204
Remove self.built = True
abuelnasr0 7e0313f
Revert "Remove self.built = True"
abuelnasr0 6eba2fa
Add built=True to MLP layer
abuelnasr0 b77a22e
Add Checkpoint conversion script
abuelnasr0 a61267f
Change LayerNorm name
abuelnasr0 f64f532
Fix typo
abuelnasr0 103664a
Fix getting HF model output
abuelnasr0 52a1160
Add and to allclose function in checkpoint conversion script
abuelnasr0 8017f4e
Remove allclose check
abuelnasr0 d700931
Add doc for bloom
abuelnasr0 3461a3e
Write batch size instead of _
abuelnasr0 5964185
Rename out_dense to output_dense
abuelnasr0 7c95c10
Rename out_dense to output_dense
abuelnasr0 069f0d2
Format to 80 chars and remove unnecessery check
abuelnasr0 d2514c9
Remove exporting BloomDecoder
abuelnasr0 344c903
Add intermediate_dim Arg
abuelnasr0 b2076ff
Format the code
abuelnasr0 fd9c64c
Remove unnecessery comment
abuelnasr0 4a5a114
Use keras gelu
abuelnasr0 1136ca2
Remove MLP layer and implement it inside BloomDecoder
abuelnasr0 2d03d2c
Split q k v heads
abuelnasr0 2eeb5f4
Remove shapes comments
abuelnasr0 531b1ff
Revert "Split q k v heads"
abuelnasr0 786677b
Revert "Revert "Split q k v heads""
abuelnasr0 eccef98
Revert "Remove shapes comments"
abuelnasr0 ea3063d
Add bias axes
abuelnasr0 38826c4
Add bias axes to the correct axes
abuelnasr0 2ad1d17
Update conversion script for splitting q,k,v
abuelnasr0 b68300f
format the code
abuelnasr0 67f7198
Rename _dropout -> _dropout_layer
abuelnasr0 33ddca2
use clone initializer instead of paasing str name
abuelnasr0 e59e9e2
Serialize kernal & bais initializers
abuelnasr0 56e4492
Format the code
abuelnasr0 d89a207
Add alibi_bias_max to _build_alibi_tensor function
abuelnasr0 0268dd7
Format the code
abuelnasr0 0c0233f
Lowercase vairiable names
abuelnasr0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright 2023 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| # Copyright 2023 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import math | ||
|
|
||
| from keras_nlp.backend import keras | ||
| from keras_nlp.backend import ops | ||
| from keras_nlp.utils.keras_utils import clone_initializer | ||
|
|
||
|
|
||
| class BloomAttention(keras.layers.Layer): | ||
| def __init__( | ||
| self, | ||
| num_heads, | ||
| dropout=0.0, | ||
| kernel_initializer="glorot_uniform", | ||
| bias_initializer="zeros", | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.num_heads = num_heads | ||
| self.dropout = dropout | ||
| self.kernel_initializer = keras.initializers.get(kernel_initializer) | ||
| self.bias_initializer = keras.initializers.get(bias_initializer) | ||
|
|
||
| def build(self, inputs_shape): | ||
| batch_size, seq_length, hidden_dim = inputs_shape | ||
|
|
||
| self.head_dim = hidden_dim // self.num_heads | ||
|
|
||
| # Layer-wise attention scaling | ||
| self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) | ||
|
|
||
| self._query_dense = keras.layers.EinsumDense( | ||
| equation="btm,mnh->btnh", | ||
| output_shape=(None, self.num_heads, self.head_dim), | ||
| bias_axes="nh", | ||
| kernel_initializer=clone_initializer(self.kernel_initializer), | ||
| bias_initializer=clone_initializer(self.bias_initializer), | ||
| dtype=self.dtype_policy, | ||
| name="query_dense", | ||
| ) | ||
| self._query_dense.build(inputs_shape) | ||
|
|
||
| self._key_dense = keras.layers.EinsumDense( | ||
| equation="bsm,mnh->bsnh", | ||
| output_shape=(None, self.num_heads, self.head_dim), | ||
| bias_axes="nh", | ||
| kernel_initializer=clone_initializer(self.kernel_initializer), | ||
| bias_initializer=clone_initializer(self.bias_initializer), | ||
| dtype=self.dtype_policy, | ||
| name="key_dense", | ||
| ) | ||
| self._key_dense.build(inputs_shape) | ||
|
|
||
| self._value_dense = keras.layers.EinsumDense( | ||
| equation="bsm,mnh->bsnh", | ||
| output_shape=(None, self.num_heads, self.head_dim), | ||
| bias_axes="nh", | ||
| kernel_initializer=clone_initializer(self.kernel_initializer), | ||
| bias_initializer=clone_initializer(self.bias_initializer), | ||
| dtype=self.dtype_policy, | ||
| name="value_dense", | ||
| ) | ||
| self._value_dense.build(inputs_shape) | ||
|
|
||
| self._output_dense = keras.layers.Dense( | ||
| hidden_dim, | ||
| kernel_initializer=clone_initializer(self.kernel_initializer), | ||
| bias_initializer=clone_initializer(self.bias_initializer), | ||
| dtype=self.dtype_policy, | ||
| name="output_dense", | ||
| ) | ||
| self._output_dense.build(inputs_shape) | ||
|
|
||
| self._dropout_layer = keras.layers.Dropout( | ||
| rate=self.dropout, dtype=self.dtype_policy, name="dropout" | ||
| ) | ||
| self._softmax = keras.layers.Softmax( | ||
| dtype=self.dtype_policy, name="softmax" | ||
| ) | ||
|
|
||
| self.built = True | ||
|
|
||
| @staticmethod | ||
| def _build_alibi_tensor(num_heads, seq_length, alibi_bias_max=8): | ||
| # this function is adopted from fairseq | ||
| # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 | ||
| def get_slopes(n): | ||
| def get_slopes_power_of_2(n): | ||
| start = 2 ** ( | ||
| -(2 ** -(math.log2(n) - math.log2(alibi_bias_max))) | ||
| ) | ||
| ratio = start | ||
| return [start * ratio**i for i in range(n)] | ||
|
|
||
| if math.log2(n).is_integer(): | ||
| return get_slopes_power_of_2(n) | ||
| else: | ||
| closest_power_of_2 = 2 ** math.floor(math.log2(n)) | ||
| return ( | ||
| get_slopes_power_of_2(closest_power_of_2) | ||
| + get_slopes(2 * closest_power_of_2)[0::2][ | ||
| : n - closest_power_of_2 | ||
| ] | ||
| ) | ||
|
|
||
| slopes = ops.convert_to_tensor(get_slopes(num_heads), dtype=float) | ||
| slopes = ops.expand_dims(slopes, 1) | ||
|
|
||
| alibi = slopes * ops.expand_dims(ops.arange(seq_length, dtype=float), 0) | ||
| alibi = ops.expand_dims(alibi, 1) | ||
| alibi = ops.expand_dims(alibi, 0) | ||
|
|
||
| return alibi | ||
|
|
||
| def call( | ||
| self, | ||
| hidden_states, | ||
| attention_mask=None, | ||
| cache=None, | ||
| cache_update_index=None, | ||
| ): | ||
| batch_size, seq_length, hidden_dim = ops.shape(hidden_states) | ||
|
|
||
| query = self._query_dense(hidden_states) | ||
| key = self._key_dense(hidden_states) | ||
| value = self._value_dense(hidden_states) | ||
|
|
||
| if cache is not None: | ||
| key_cache = cache[:, 0, ...] | ||
| value_cache = cache[:, 1, ...] | ||
| if cache_update_index is None: | ||
| key = key_cache | ||
| value = value_cache | ||
| else: | ||
| start = [0, cache_update_index, 0, 0] | ||
| key = ops.slice_update(key_cache, start, key) | ||
| value = ops.slice_update(value_cache, start, value) | ||
| cache = ops.stack((key, value), axis=1) | ||
| else: | ||
| if cache_update_index is not None: | ||
| raise ValueError( | ||
| "`cache_update_index` should not be set if `cache` is " | ||
| f"`None`. Received: cache={cache}, " | ||
| f"cache_update_index={cache_update_index}" | ||
| ) | ||
|
|
||
| # query (batch_size, num_heads, query_length, head_dim) | ||
| query = ops.transpose(query, [0, 2, 1, 3]) | ||
| # value (batch_size, num_heads, kv_length, head_dim) | ||
| value = ops.transpose(value, [0, 2, 1, 3]) | ||
| # key (batch_size, num_heads, head_dim, kv_length) | ||
| key = ops.transpose(key, [0, 2, 3, 1]) | ||
|
|
||
| alibi = self._build_alibi_tensor( | ||
| num_heads=self.num_heads, seq_length=seq_length | ||
| ) | ||
|
|
||
| scores = ( | ||
| ops.matmul(query, key) * self.inv_norm_factor + alibi | ||
| ) # [batch_size, num_heads, query_length, kv_length] | ||
|
|
||
| scores = self._softmax(scores, ops.expand_dims(attention_mask, 1)) | ||
|
|
||
| scores = self._dropout_layer(scores) | ||
|
|
||
| attention_output = ops.matmul( | ||
| scores, value | ||
| ) # [batch_size, num_heads, query_length, head_dim] | ||
|
|
||
| attention_output = ops.transpose( | ||
| attention_output, [0, 2, 1, 3] | ||
| ) # [batch_size, query_length, num_heads, head_dim] | ||
| attention_output = ops.reshape( | ||
| attention_output, | ||
| [batch_size, seq_length, self.num_heads * self.head_dim], | ||
| ) # [batch_size, query_length, hidden_dim] | ||
|
|
||
| attention_output = self._output_dense(attention_output) | ||
| attention_output = self._dropout_layer(attention_output) | ||
|
|
||
| if cache is not None: | ||
| return attention_output, cache | ||
|
|
||
| return attention_output | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "num_heads": self.num_heads, | ||
| "dropout": self.dropout, | ||
| "kernel_initializer": keras.initializers.serialize( | ||
| self.kernel_initializer | ||
| ), | ||
| "bias_initializer": keras.initializers.serialize( | ||
| self.bias_initializer | ||
| ), | ||
| } | ||
| ) | ||
| return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| # Copyright 2023 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from keras_nlp.api_export import keras_nlp_export | ||
| from keras_nlp.backend import keras | ||
| from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding | ||
| from keras_nlp.models.backbone import Backbone | ||
| from keras_nlp.models.bloom.bloom_decoder import BloomDecoder | ||
|
|
||
|
|
||
| def _bloom_kernel_initializer(stddev=0.02): | ||
| return keras.initializers.RandomNormal(stddev=stddev) | ||
|
|
||
|
|
||
| @keras_nlp_export("keras_nlp.models.BloomBackbone") | ||
| class BloomBackbone(Backbone): | ||
| """A Bloom decoder network. | ||
|
|
||
| This network implements a Transformer-based decoder network, BigScience | ||
| Language Open-science Open-access Multilingual (BLOOM), as descriped in | ||
| ["BLOOM: A 176B-Parameter Open-Access Multilingual Language Model"](https://arxiv.org/pdf/2211.05100.pdf). | ||
|
|
||
| The default constructor gives a fully customizable, randomly initialized | ||
| Bloom model with any number of layers, heads, and embedding dimensions. To | ||
| load preset architectures and weights, use the `from_preset()` constructor. | ||
|
|
||
| Disclaimer: Pre-trained models are provided on an "as is" basis, without | ||
| warranties or conditions of any kind. | ||
|
|
||
| Args: | ||
| vocabulary_size: int. The size of the token vocabulary. | ||
| num_layers: int. The number of transformer layers. | ||
| num_heads: int. The number of attention heads for each transformer. | ||
| The hidden size must be divisible by the number of attention heads. | ||
| hidden_dim: int. The dimensionality of the embeddings and hidden states. | ||
| intermediate_dim: int. The output dimension of the first Dense layer in | ||
| the MLP network of each transformer. | ||
| dropout: float. Dropout probability for the Transformer decoder. | ||
| layer_norm_epsilon: float. Epsilon for the layer normalization layers in | ||
| the transformer decoder. | ||
| max_sequence_length: int. The maximum sequence length that this decoder | ||
| can consume. | ||
|
|
||
| Examples: | ||
| ```python | ||
| input_data = { | ||
| "token_ids": np.ones(shape=(1, 12), dtype="int32"), | ||
| "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), | ||
| } | ||
|
|
||
| # Randomly initialized BLOOM decoder with a custom config. | ||
| model = keras_nlp.models.BloomBackbone( | ||
| vocabulary_size=10, | ||
| num_layers=2, | ||
| num_heads=2, | ||
| hidden_dim=32, | ||
| intermediate_dim=32*4, | ||
| dropout=0.0, | ||
| layer_norm_epsilon=1e-5, | ||
| max_sequence_length=128, | ||
| ) | ||
| model(input_data) | ||
| ``` | ||
|
|
||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| vocabulary_size, | ||
| num_layers, | ||
| num_heads, | ||
| hidden_dim, | ||
| intermediate_dim, | ||
| dropout=0.0, | ||
| layer_norm_epsilon=1e-5, | ||
| max_sequence_length=512, | ||
| **kwargs, | ||
| ): | ||
| token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") | ||
| padding_mask = keras.Input( | ||
| shape=(None,), dtype="int32", name="padding_mask" | ||
| ) | ||
|
|
||
| # Embed tokens | ||
| token_embedding_layer = ReversibleEmbedding( | ||
| input_dim=vocabulary_size, | ||
| output_dim=hidden_dim, | ||
| embeddings_initializer=_bloom_kernel_initializer(stddev=0.02), | ||
| tie_weights=False, | ||
| name="token_embedding", | ||
| ) | ||
| token_embedding = token_embedding_layer(token_ids) | ||
|
|
||
| x = keras.layers.LayerNormalization( | ||
| epsilon=layer_norm_epsilon, name="token_embedding_layernorm" | ||
| )(token_embedding) | ||
|
|
||
| for i in range(num_layers): | ||
| x = BloomDecoder( | ||
| num_heads=num_heads, | ||
| intermediate_dim=intermediate_dim, | ||
| dropout=dropout, | ||
| layer_norm_epsilon=layer_norm_epsilon, | ||
| name=f"transformer_layer_{i}", | ||
| )(x, decoder_padding_mask=padding_mask) | ||
|
|
||
| sequence_output = keras.layers.LayerNormalization( | ||
| epsilon=layer_norm_epsilon, name="final_layernorm" | ||
| )(x) | ||
|
|
||
| super().__init__( | ||
| inputs={ | ||
| "token_ids": token_ids, | ||
| "padding_mask": padding_mask, | ||
| }, | ||
| outputs=sequence_output, | ||
| **kwargs, | ||
| ) | ||
| self.vocabulary_size = vocabulary_size | ||
| self.num_layers = num_layers | ||
| self.num_heads = num_heads | ||
| self.hidden_dim = hidden_dim | ||
| self.intermediate_dim = intermediate_dim | ||
| self.dropout = dropout | ||
| self.layer_norm_epsilon = layer_norm_epsilon | ||
| self.max_sequence_length = max_sequence_length | ||
| self.token_embedding = token_embedding_layer | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "vocabulary_size": self.vocabulary_size, | ||
| "num_layers": self.num_layers, | ||
| "num_heads": self.num_heads, | ||
| "hidden_dim": self.hidden_dim, | ||
| "intermediate_dim": self.intermediate_dim, | ||
| "dropout": self.dropout, | ||
| "layer_norm_epsilon": self.layer_norm_epsilon, | ||
| "max_sequence_length": self.max_sequence_length, | ||
| } | ||
| ) | ||
| return config | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.