diff --git a/keras_nlp/layers/modeling/transformer_layer_utils.py b/keras_nlp/layers/modeling/transformer_layer_utils.py index 863da59a36..f375bf1b9d 100644 --- a/keras_nlp/layers/modeling/transformer_layer_utils.py +++ b/keras_nlp/layers/modeling/transformer_layer_utils.py @@ -55,9 +55,12 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0): `(batch_size, output_length, input_length)` that can be passed to a attention layer. """ - i = ops.expand_dims(ops.arange(output_length), axis=1) + cache_index - j = ops.arange(input_length) - mask = ops.expand_dims(ops.cast(i >= j, dtype="int32"), axis=0) + i = ops.arange(output_length, dtype="float32") + i = i + ops.cast(cache_index, "float32") + i = ops.expand_dims(i, axis=1) + j = ops.arange(input_length, dtype="float32") + mask = ops.expand_dims(i >= j, axis=0) + return ops.broadcast_to(mask, (batch_size, output_length, input_length)) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 8fd6a70ac0..cdd50670f3 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -75,6 +75,13 @@ ) from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 6fccf6013a..9c8cdaa60e 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -152,3 +152,80 @@ def from_preset(calling_cls, *args, **kwargs): example_preset_name=next(iter(cls.presets), ""), preset_names='", "'.join(cls.presets), )(cls.from_preset.__func__) + + def enable_lora(self, rank): + """Enable Lora on the backbone. + + Calling this method will freeze all weights on the backbone, + while enabling Lora on the query & value `EinsumDense` layers + of the attention layers. + """ + target_names = ["query_dense", "value_dense", "query", "value"] + self.trainable = True + self._lora_enabled_layers = [] + self._lora_rank = rank + for layer in self._flatten_layers(include_self=False): + layer.trainable = False + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for i, layer in enumerate(all_layers): + for name in target_names: + if layer.name == name: + if hasattr(layer, "enable_lora"): + layer.trainable = True + layer.enable_lora(rank) + self._lora_enabled_layers.append(i) + + def save_lora_weights(self, filepath): + if not getattr(self, "_lora_enabled_layers", []): + raise ValueError( + "There are no lora-enabled layers in this model. " + "Make sure to call `.enable_lora(rank)` first." + ) + if not str(filepath).endswith(".lora.h5"): + raise ValueError( + "The filename must end in `.lora.h5`. 
" + f"Received: filepath={filepath}" + ) + + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w") + lora_store = store.make("lora") + lora_store["rank"] = self._lora_rank + # We cannot identify layers by name since names are non-unique, + # so we identify them by index in the topologically sorted list + # of layers that have weights. + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + # We only lora the einsumdense layers, + # so the factored weights are always named `kernel` + layer = all_layers[layer_index] + inner_store = store.make(f"lora/{layer_index}") + inner_store["lora_kernel_a"] = layer.lora_kernel_a + inner_store["lora_kernel_b"] = layer.lora_kernel_b + store.close() + + def load_lora_weights(self, filepath): + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r") + lora_store = store.get("lora") + rank = int(lora_store["rank"][()]) + + if not getattr(self, "_lora_enabled_layers", []): + self.enable_lora(rank) + else: + if self._lora_rank != rank: + raise ValueError( + f"The Lora rank expected by file '{filepath}' " + f"is rank={rank}, but the model was called with " + f"`.enable_lora(rank={self._lora_rank})`. " + "Both ranks must match." + ) + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + layer = all_layers[layer_index] + lora_kernel_a = store.get(f"lora/{layer_index}")["lora_kernel_a"] + lora_kernel_b = store.get(f"lora/{layer_index}")["lora_kernel_b"] + layer.lora_kernel_a.assign(lora_kernel_a) + layer.lora_kernel_b.assign(lora_kernel_b) + store.close() diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index c17eafdb02..c530555b3d 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -479,6 +479,7 @@ def repeat_tensor(x): mask=decoder_padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/gemma/__init__.py b/keras_nlp/models/gemma/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/gemma/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py new file mode 100644 index 0000000000..80c2ac6a63 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -0,0 +1,197 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.utils.keras_utils import clone_initializer + + +class CachedGemmaAttention(keras.layers.Layer): + """A cached grouped query attention layer.""" + + def __init__( + self, + head_dim, + num_query_heads, + num_key_value_heads, + kernel_initializer="glorot_uniform", + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.dropout = dropout + + self._kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.num_key_value_groups = num_query_heads // num_key_value_heads + + def build(self, inputs_shape): + self.hidden_dim = inputs_shape[-1] + + self.query_dense = keras.layers.EinsumDense( + "btd,ndh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + "bsd,kdh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + "bsd,kdh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(inputs_shape) + + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + self.softmax = keras.layers.Softmax(dtype="float32") + self.built = True + + def _apply_rope(self, x, positions): + """Rope rotate q or k.""" + # TODO: refactor to use RotaryEmbedding layer? + max_wavelength = 10000 + x_shape = ops.shape(x) + freq_exponents = (2.0 / x_shape[-1]) * ops.cast( + ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype + ) + timescale = max_wavelength**freq_exponents + radians = positions[..., None] / timescale[None, None, :] + radians = radians[..., None, :] + sin, cos = ops.sin(radians), ops.cos(radians) + x1, x2 = ops.split(x, 2, axis=-1) + # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA + # compilation on jax. 
We should be able to remove this once the + # following PR is in all jax releases we care about: + # https://github.com/openxla/xla/pull/7875 + output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + return ops.reshape(output, x_shape) + + def _compute_attention( + self, + q, + k, + v, + attention_mask, + training=False, + ): + query_normalization = 1 / np.sqrt(self.head_dim) + + q *= ops.cast(query_normalization, dtype=q.dtype) + q_shape = ops.shape(q) + q = ops.reshape( + q, + ( + *q_shape[:-2], + self.num_key_value_heads, + self.num_query_heads // self.num_key_value_heads, + q_shape[-1], + ), + ) + b, q_len, _, _, h = ops.shape(q) + + attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k) + attention_mask = attention_mask[:, None, None, :, :] + orig_dtype = attention_logits.dtype + attention_softmax = self.softmax(attention_logits, mask=attention_mask) + attention_softmax = ops.cast(attention_softmax, orig_dtype) + + if self.dropout: + attention_softmax = self.dropout_layer( + attention_softmax, training=training + ) + + results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v) + return ops.reshape(results, (b, q_len, self.num_query_heads, h)) + + def call( + self, + x, + attention_mask=None, + cache=None, + cache_update_index=0, + training=False, + ): + seq_len = ops.shape(x)[1] + start_index = cache_update_index + positions = ops.cast( + ops.arange(seq_len, dtype="float32"), self.compute_dtype + ) + positions = positions + ops.cast(start_index, self.compute_dtype) + query = self.query_dense(x) + query = self._apply_rope(query, positions) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + key_update = self.key_dense(x) + key_update = self._apply_rope(key_update, positions) + value_update = self.value_dense(x) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + key = self.key_dense(x) + key = self._apply_rope(key, positions) + value = self.value_dense(x) + + attention_vec = self._compute_attention( + query, key, value, attention_mask, training=training + ) + + # Wipe attn vec if there are no attended tokens. + no_attended_tokens = ops.all( + ops.equal(attention_mask, 0), axis=-1, keepdims=True + )[..., None] + attention_vec = ops.where( + no_attended_tokens, ops.zeros_like(attention_vec), attention_vec + ) + + attention_output = self.output_dense(attention_vec) + + if cache is not None: + return attention_output, cache + return attention_output diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py new file mode 100644 index 0000000000..e5814940aa --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -0,0 +1,267 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
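`CachedGemmaAttention` above implements grouped-query attention: the `num_query_heads` query heads are reshaped into `num_key_value_heads` groups that all share a single key/value head, and the two einsums in `_compute_attention` do the scoring and value mixing directly in that grouped layout. The following is a minimal NumPy sketch, not part of this diff, that replays those two einsums with illustrative dimension values; masking, dropout, and the float32 softmax of the real layer are omitted.

```python
import numpy as np

b, t, s = 2, 5, 5          # batch, query length, key/value length
k, g, h = 1, 4, 8          # num_key_value_heads, queries per kv head, head_dim
n = k * g                  # num_query_heads

q = np.random.rand(b, t, n, h)
key = np.random.rand(b, s, k, h)
value = np.random.rand(b, s, k, h)

# Group the query heads so each group shares one key/value head, and scale.
q_grouped = q.reshape(b, t, k, g, h) / np.sqrt(h)
logits = np.einsum("btkgh,bskh->bkgts", q_grouped, key)

# Plain softmax over the key axis (the layer uses a masked float32 softmax).
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)

# Weight the shared value heads, then return to (batch, t, num_query_heads, head_dim).
out = np.einsum("bkgts,bskh->btkgh", probs, value)
out = out.reshape(b, t, n, h)
print(out.shape)  # (2, 5, 4, 8)
```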
+ +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.gemma.gemma_decoder_block import GemmaDecoderBlock +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.rms_normalization import RMSNormalization +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaBackbone") +class GemmaBackbone(Backbone): + """Gemma core network with hyperparameters. + + This backbone implements the base Transformer network for the Gemma model. + It includes the embedding lookups and transformer layers. This backbone + will output the final hidden states for each token, not generative + predictions over the vocabulary space. For a higher-level object for text + generation, see `keras_nlp.models.GemmaCausalLM`. + + The default constructor gives a fully customizable, randomly initialized + Gemma model with any number of layers, heads, and embedding dimensions. To + load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end + of each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value user for every layer norm + in the transformer model. + dropout: float. Dropout probability for the Transformer encoder. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the models computations and weights. Note that some + computations, such as softmax and layer normalization will always + be done a float32 precision regardless of dtype. + + Example usage: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Gemma decoder. + model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en") + model(input_data) + + # Randomly initialized Gemma decoder with custom config. + model = keras_nlp.models.GemmaBackbone( + vocabulary_size=50257, + num_layers=12, + num_query_heads=12, + num_key_value_heads=1, + hidden_dim=768, + intermediate_dim=3072, + head_dim=64, + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + head_dim, + layer_norm_epsilon=1e-6, + dropout=0, + dtype=None, + **kwargs, + ): + if not config.keras_3(): + raise ValueError( + "`GemmaBackbone` requires Keras 3. Run `pip install -U keras` " + "upgrade your Keras version, or see https://keras.io/getting_started/ " + "for more info on Keras versions and installation." 
+ ) + + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=True, + embeddings_initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + seed=None, + ), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = GemmaDecoderBlock( + intermediate_dim=intermediate_dim, + hidden_dim=hidden_dim, + num_query_heads=num_query_heads, + head_dim=head_dim, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=dtype, + name=f"decoder_block_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = RMSNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_normalization", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="float32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="float32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @staticmethod + def get_layout_map(device_mesh, model_parallel_dim_name="model"): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the gemma + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Sample usage: + ``` + # Feel free to change the mesh shape to balance data and model parallel + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), axis_names=('batch', 'model'), + devices=keras.distribution.list_devices()) + layout_map = GemmaBackbone.get_layout_map( + mesh, model_parallel_dim_name="model") + + distribution = keras.distribution.ModelParallel( + mesh, layout_map, batch_dim_name='batch') + with distribution.scope(): + gemma_model = keras_nlp.models.GemmaCausalLM.from_preset() + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + of all the model weights. 
+ """ + # The weight path and shape of the Gemma backbone is like below (for 2G) + # token_embedding/embeddings, (256128, 2048), 524550144 + # repeat block for decoder + # ... + # decoder_block_17/pre_attention_norm/scale, (2048,), 2048 + # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304 + # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048 + # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432 + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.Device`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + model_dim = model_parallel_dim_name + # The sharding is partition for the hidden_dim of the model. + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (None, model_dim) + layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ( + None, + model_dim, + None, + ) + layout_map["decoder_block.*attention_output.*kernel"] = ( + None, + None, + model_dim, + ) + layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None) + layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim) + + return layout_map diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py new file mode 100644 index 0000000000..c66d318fd5 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -0,0 +1,128 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 256128, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 4, + "hidden_dim": 128, + "intermediate_dim": 256, + "head_dim": 128, + "layer_norm_epsilon": 1e-6, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 128), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaBackbone, + preset="gemma_2b_en", + input_data={ + "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]), + "padding_mask": ops.ones((1, 5), dtype="int32"), + }, + expected_output_shape=(1, 5, 2048), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [1.073359, 0.262374, 0.170238, 0.605402, 2.336161] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaBackbone.presets: + self.run_preset_test( + cls=GemmaBackbone, + preset=preset, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = GemmaBackbone(**self.init_kwargs) + self.assertEqual(model.count_params(), 33407616) + self.assertEqual(len(model.layers), 6) + + def test_distribution(self): + if keras.backend.backend() != "jax": + return + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. 
+ return + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = GemmaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(device_mesh, layout_map) + with distribution.scope(): + model = GemmaBackbone(**self.init_kwargs) + + for w in model.weights: + if "token_embedding/embeddings" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, "model")) + if "attention/query/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, "model", None) + ) + if "attention/key/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, "model", None) + ) + if "attention/value/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, "model", None) + ) + if "attention/attention_output/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, "model") + ) + if "ffw_gating/kernel" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), ("model", None)) + if "ffw_gating_2/kernel" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), ("model", None)) + if "ffw_linearl" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, "model")) diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py new file mode 100644 index 0000000000..45c7c6abe0 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -0,0 +1,441 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaCausalLM") +class GemmaCausalLM(GenerativeTask): + """An end-to-end Gemma model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a Gemma model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. 
+ + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + backbone: A `keras_nlp.models.GemmaBackbone` instance. + preprocessor: A `keras_nlp.models.GemmaCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + gemma_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + gemma_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + gemma_lm.compile(sampler="top_k") + gemma_lm.generate("I want to say", max_length=30) + + gemma_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2)) + gemma_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for " Keras is". + "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + } + + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset( + "gemma_2b_en", + preprocessor=None, + ) + gemma_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + gemma_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + # Token ids for " Keras is deep learning library" + "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2), + } + y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2) + + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset( + "gemma_2b_en", + preprocessor=None, + ) + gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. 
+ ```python + tokenizer = keras_nlp.models.GemmaTokenizer( + proto="proto.spm", + ) + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_nlp.models.GemmaBackbone( + vocabulary_size=30552, + num_layers=4, + num_heads=4, + hidden_dim=256, + intermediate_dim=512, + max_sequence_length=128, + ) + gemma_lm = keras_nlp.models.GemmaCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + gemma_lm.fit(x=features, batch_size=2) + ``` + """ + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + sampler="greedy", + jit_compile=True, + ) + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def backbone_cls(cls): + return GemmaBackbone + + @classproperty + def preprocessor_cls(cls): + return GemmaCausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `GemmaCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs in the + whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype) + # Each decoder layer has a cache; we update them separately. + caches = [] + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + current_cache = cache[:, i, ...] + x, next_cache = transformer_layer( + x, + cache=current_cache, + cache_update_index=cache_update_index, + ) + caches.append(next_cache) + cache = ops.stack(caches, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.head_dim + shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. 
+ _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + end_token_id=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + end_token_id: The id of the end token to stop on. If all + sequences have produced a new `end_token_id`, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self._sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + end_token_id=end_token_id, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if end_token_id is not None: + # Build a mask of `end_token_id` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = ops.logical_and( + ops.equal(token_ids, end_token_id), + ops.logical_not(padding_mask), + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `GemmaCausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. This is an + artifact required by the GemmaBackbone and isn't influential on + the computation of this function. If omitted, this function uses + `keras.ops.ones()` to create a tensor of the appropriate shape. + scoring_mode: The type of scores to return, either "logits" or + "loss", both will be per input token. 
+ layer_intercept_fn: An optional function for augmenting activations + with additional computation, for example, as part of + interpretability research. This function will be passed the + activations as its first parameter and a numeric index + associated with that backbone layer. _This index _is not_ an + index into `self.backbone.layers`_. The index -1 accompanies the + embeddings returned by calling `self.backbone.token_embedding()` + on `token_ids` in the forward direction. All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. + + Examples: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset( + "gemma_2b_en" + ) + generations = gemma_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = gemma_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = gemma_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." 
+ ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + x = token_embeddings * ops.cast( + ops.sqrt(self.backbone.hidden_dim), dtype=self.compute_dtype + ) + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py new file mode 100644 index 0000000000..20c66edff3 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py @@ -0,0 +1,173 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.GemmaCausalLMPreprocessor") +class GemmaCausalLMPreprocessor(GemmaPreprocessor): + """Gemma Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.GemmaCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.GemmaCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_nlp.models.GemmaTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. 
Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( + "gemma_2b_en" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Apply tokenization to a `tf.data.Dataset`. + features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation (no end token). + preprocessor.generate_preprocess(["The quick brown fox jumped."]) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'token_ids': np.array([[2, 714, 4320, 8426, 25341, 32292, 235265, 0]]), + 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]), + }) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`GemmaCausalLMPreprocessor` generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Covert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Covert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. 
+ """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + mask = ops.convert_to_numpy(padding_mask) + # Also strip any special tokens during detokenization (e.g. the start + # and end markers). In the future we could make this configurable. + mask = mask & (token_ids != self.tokenizer.start_token_id) + mask = mask & (token_ids != self.tokenizer.pad_token_id) + mask = mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..121621da85 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py @@ -0,0 +1,92 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = GemmaCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[9, 5, 7, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 5, 7, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 4, 9, 5, 7, 2, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLMPreprocessor.presets: + self.run_preset_test( + cls=GemmaCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py new file mode 100644 index 0000000000..0e1d7a14f8 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -0,0 +1,245 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from unittest.mock import patch + +import keras +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaCausalLMTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.preprocessor = GemmaCausalLMPreprocessor( + self.tokenizer, + sequence_length=8, + ) + self.backbone = GemmaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=2, + num_key_value_heads=1, + hidden_dim=4, + intermediate_dim=8, + head_dim=2, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the quick brown fox"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=GemmaCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 11), + ) + + def test_generate(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate("the quick brown fox") + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_generate_with_bfloat16(self): + original_floatx = keras.config.floatx() + keras.config.set_floatx("float16") + try: + causal_lm = GemmaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate("the quick brown fox") + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. 
+ keras.config.set_floatx(original_floatx) + + def test_early_stopping(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLM.presets: + self.run_preset_test( + cls=GemmaCausalLM, + preset=preset, + input_data=self.input_data, + ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + target_ids = keras.ops.roll(token_ids, shift=-1, axis=1) + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 4) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. 
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Setup a custom intercept function that extracts the embeddings to a + # a variable from the embeddings layer and otherwise asserts on shapes. + embedded_prompts = None + + def layer_intercept_fn_for_testing(x, i): + if i == -1: + nonlocal embedded_prompts + embedded_prompts = x + else: + nonlocal expected_embedded_shape + self.assertEqual(ops.shape(x), expected_embedded_shape) + return x + + # Get the scores. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + layer_intercept_fn=layer_intercept_fn_for_testing, + ) + + # Assert shapes for info exfiltrated into the parent context. + self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py new file mode 100644 index 0000000000..0a91655fc4 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_decoder_block.py @@ -0,0 +1,189 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
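+"""Gemma decoder block (attention plus gated feed-forward with RMSNorm)."""
+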
+from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.gemma.gemma_attention import CachedGemmaAttention +from keras_nlp.models.gemma.rms_normalization import RMSNormalization + + +class GemmaDecoderBlock(keras.layers.Layer): + def __init__( + self, + hidden_dim, + intermediate_dim, + head_dim, + num_query_heads, + num_key_value_heads, + layer_norm_epsilon=1e-6, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + self.pre_attention_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_norm", + ) + + self.attention = CachedGemmaAttention( + head_dim=head_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=self.dtype_policy, + name="attention", + ) + + if self.dropout > 0: + self.attention_dropout = keras.layers.Dropout(rate=dropout) + self.feedforward_dropout = keras.layers.Dropout(rate=dropout) + + self.pre_ffw_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_ffw_norm", + ) + + self.gating_ffw = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating", + ) + + self.gating_ffw_2 = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating_2", + ) + + self.ffw_linear = keras.layers.EinsumDense( + equation="btf,fd->btd", + output_shape=(None, self.hidden_dim), + dtype=self.dtype_policy, + name="ffw_linear", + ) + + def build(self, input_shape): + self.pre_attention_norm.build(input_shape) + self.attention.build(input_shape) + + shape = input_shape + self.pre_ffw_norm.build(shape) + self.gating_ffw.build(shape) + self.gating_ffw_2.build(shape) + + shape = self.gating_ffw.compute_output_shape(shape) + self.ffw_linear.build(shape) + self.built = True + + def compute_output_shape(self, input_shape): + # Isometric + return input_shape + + def _compute_attention_mask( + self, x, padding_mask, cache, cache_update_index + ): + decoder_mask = merge_padding_and_attention_mask( + inputs=x, padding_mask=padding_mask, attention_mask=None + ) + batch_size = ops.shape(x)[0] + input_length = output_length = ops.shape(x)[1] + if cache is not None: + input_length = ops.shape(cache)[2] + + causal_mask = compute_causal_mask( + batch_size=batch_size, + input_length=input_length, + output_length=output_length, + cache_index=cache_update_index, + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def call( + self, + x, + padding_mask=None, + cache=None, + cache_update_index=0, + ): + normalized_x = self.pre_attention_norm(x) + attention_mask = self._compute_attention_mask( + normalized_x, padding_mask, cache, cache_update_index + ) + if cache is not None: + attention, new_cache = self.attention( + normalized_x, + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + ) + else: + 
attention = self.attention( + normalized_x, + attention_mask=attention_mask, + ) + + if self.dropout: + attention = self.attention_dropout(attention) + + attention_x = x + attention + normalized_x = self.pre_ffw_norm(attention_x) + + x1 = self.gating_ffw(normalized_x) + x2 = self.gating_ffw_2(normalized_x) + x = keras.activations.gelu(x1, approximate=True) * x2 + x = self.ffw_linear(x) + + x = x + attention_x + + if cache is not None: + return x, new_cache + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_nlp/models/gemma/gemma_lora_test.py b/keras_nlp/models/gemma/gemma_lora_test.py new file mode 100644 index 0000000000..1cbbdfa67f --- /dev/null +++ b/keras_nlp/models/gemma/gemma_lora_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest + +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaLoraTest(TestCase): + def setUp(self): + self._init_kwargs = { + "vocabulary_size": 50, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 2, + "hidden_dim": 32, + "intermediate_dim": 16, + "head_dim": 16, + "layer_norm_epsilon": 1e-6, + } + + def test_lora_fine_tuning(self): + # Set up backbone and preprocessor. + backbone = GemmaBackbone(**self._init_kwargs) + backbone.enable_lora(4) + # 4 layers, 2 weights per layer + self.assertLen(backbone.trainable_weights, 4 * 2) + self.assertLen(backbone.non_trainable_weights, 20) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + + # Test fine-tuning + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + # Test saving and reloading. 
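+        # Reloading full `.weights.h5` weights into a fresh backbone
+        # should reproduce the same outputs.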
+ temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + backbone.save_weights(temp_filepath) + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(temp_filepath) + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + def test_lora_saving_and_reloading(self): + backbone = GemmaBackbone(**self._init_kwargs) + initial_model_filepath = os.path.join( + self.get_temp_dir(), "base.weights.h5" + ) + backbone.save_weights(initial_model_filepath) + + backbone.enable_lora(4) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + lora_filepath = os.path.join(self.get_temp_dir(), "lora_model.lora.h5") + backbone.save_lora_weights(lora_filepath) + + # New backbone with same initial weights + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(initial_model_filepath) + new_backbone.enable_lora(4) + new_backbone.load_lora_weights(lora_filepath) + + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + # Test exceptions + backbone = GemmaBackbone(**self._init_kwargs) + with self.assertRaisesRegex(ValueError, "no lora-enabled layers"): + backbone.save_lora_weights(lora_filepath) + backbone.enable_lora(5) + with self.assertRaisesRegex(ValueError, "ranks must match"): + backbone.load_lora_weights(lora_filepath) + with self.assertRaisesRegex(ValueError, "filename must end in"): + backbone.save_lora_weights("bad_filepath") diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py new file mode 100644 index 0000000000..8fc3beb48c --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -0,0 +1,199 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaPreprocessor") +class GemmaPreprocessor(Preprocessor): + """Gemma preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can + be passed directly to a `keras_nlp.models.GemmaBackbone`. 
+ + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + The call method of this layer accepts three arguments, `x`, `y`, and + `sample_weight`. `x` can be a python string or tensor representing a single + segment, a list of python strings representing a batch of single segments, + or a list of tensors representing multiple segments to be packed together. + `y` and `sample_weight` are both optional, can have any format, and will be + passed through unaltered. + + `GemmaPreprocessor` expects the input to have only one segment, as Gemma is + mainly used for generation tasks. For tasks having multi-segment inputs + please combine inputs into a single string input before passing to the + preprocessor layer. + + Args: + tokenizer: A `keras_nlp.models.GemmaTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. Will be passed through unaltered. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset( + "gemma_2b_en" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Custom vocabulary. + bytes_io = io.BytesIO() + ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=ds.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=8, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + tokenizer = keras_nlp.models.GemmaTokenizer( + proto=bytes_io.getvalue(), + ) + preprocessor = keras_nlp.models.GemmaPreprocessor(tokenizer=tokenizer) + preprocessor("The quick brown fox jumped.") + ``` + + Apply preprocessing to a `tf.data.Dataset`. + ```python + preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset( + "gemma_2b_en" + ) + + text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + label = tf.constant([1, 1]) + + # Map labeled single sentences. + ds = tf.data.Dataset.from_tensor_slices((text, label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled single sentences. + ds = tf.data.Dataset.from_tensor_slices(text) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=8192, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.tokenizer = tokenizer + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. 
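+        # The packer adds start/end tokens and pads to `sequence_length`.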
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "GemmaPreprocessor requires each input to contain only " + f"one segment, but received {len(x)}. If you are using Gemma " + "for a multi-segment classification task, please combine your " + "input into a single string." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def tokenizer_cls(cls): + return GemmaTokenizer diff --git a/keras_nlp/models/gemma/gemma_preprocessor_test.py b/keras_nlp/models/gemma/gemma_preprocessor_test.py new file mode 100644 index 0000000000..f54a509979 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + preprocessor = GemmaPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "the quick brown fox" + preprocessor = GemmaPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 2]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaPreprocessor.presets: + self.run_preset_test( + cls=GemmaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_presets.py b/keras_nlp/models/gemma/gemma_presets.py new file mode 100644 index 0000000000..f63fef17fa --- /dev/null +++ b/keras_nlp/models/gemma/gemma_presets.py @@ -0,0 +1,66 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemma model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "gemma_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/1", + }, + "gemma_instruct_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/1", + }, + "gemma_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). 
" + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/1", + }, + "gemma_instruct_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). " + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/1", + }, +} diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py new file mode 100644 index 0000000000..6a4bb76ea0 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -0,0 +1,108 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaTokenizer") +class GemmaTokenizer(SentencePieceTokenizer): + """Gemma tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Gemma models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Gemma preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.GemmaTokenizer.from_preset("gemma_2b_en") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
+ bytes_io = io.BytesIO() + ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=ds.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=8, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + tokenizer = keras_nlp.models.GemmaTokenizer( + proto=bytes_io.getvalue(), + ) + tokenizer("The quick brown fox jumped.") + ``` + """ + + def __init__(self, proto, **kwargs): + self.start_token = "" + self.end_token = "" + self.pad_token = "" + + super().__init__(proto=proto, **kwargs) + + def set_proto(self, proto): + super().set_proto(proto) + if proto is not None: + for token in [self.end_token, self.pad_token]: + if token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{token}'` in the provided " + f"`vocabulary`. Please provide `'{token}'` in your " + "`vocabulary` or use a pretrained `vocabulary` name." + ) + self.start_token_id = self.token_to_id(self.start_token) + self.end_token_id = self.token_to_id(self.end_token) + self.pad_token_id = self.token_to_id(self.pad_token) + else: + self.start_token_id = None + self.end_token_id = None + self.pad_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/gemma/gemma_tokenizer_test.py b/keras_nlp/models/gemma/gemma_tokenizer_test.py new file mode 100644 index 0000000000..1c617dd937 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_tokenizer_test.py @@ -0,0 +1,67 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaTokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using create_gemma_test_proto.py + "proto": os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ) + } + self.input_data = ["the quick brown fox", "the earth is round"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[4, 9, 5, 7], [4, 6, 8, 10]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + GemmaTokenizer( + # Generated using create_no_special_token_proto.py + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaTokenizer, + preset="gemma_2b_en", + input_data=["The quick brown fox."], + expected_output=[[651, 4320, 8426, 25341, 235265]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaTokenizer.presets: + self.run_preset_test( + cls=GemmaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/rms_normalization.py b/keras_nlp/models/gemma/rms_normalization.py new file mode 100644 index 0000000000..ce9bdaf880 --- /dev/null +++ b/keras_nlp/models/gemma/rms_normalization.py @@ -0,0 +1,40 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class RMSNormalization(keras.layers.Layer): + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(input_shape[-1],), + initializer="zeros", + ) + self.built = True + + def call(self, x): + # Always compute normalization in float32. 
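+        # This keeps the reduction stable when running under float16 or
+        # bfloat16 dtype policies.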
+ x = ops.cast(x, "float32") + scale = ops.cast(self.scale, "float32") + var = ops.mean(ops.square(x), axis=-1, keepdims=True) + normed_inputs = x * ops.reciprocal(ops.sqrt(var + 1e-06)) + normed_inputs = normed_inputs * (1 + scale) + return ops.cast(normed_inputs, self.compute_dtype) diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py index 9a461926e4..598217d964 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/generative_task.py @@ -101,12 +101,7 @@ def compiled_generate_function(inputs, end_token_id, state): for v in self._sampler.variables: new_v = scope.get_current_value(v) sampler_variables.append(new_v if new_v is not None else v) - state = ( - sampler_variables, - trainable_variables, - non_trainable_variables, - ) - return outputs, state + return outputs, sampler_variables def wrapped_generate_function( inputs, @@ -115,18 +110,20 @@ def wrapped_generate_function( # Create an explicit tuple of all variable state. state = ( self._sampler.variables, - self.trainable_variables, - self.non_trainable_variables, + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], ) inputs = tree.map_structure(ops.convert_to_tensor, inputs) - outputs, state = compiled_generate_function( + outputs, sampler_variables = compiled_generate_function( inputs, end_token_id, state, ) # Only assign the sampler variables (random seeds), as other # model variables should never be updated in generation. - for ref_v, v in zip(self._sampler.variables, state[0]): + for ref_v, v in zip(self._sampler.variables, sampler_variables): ref_v.assign(v) return outputs diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index e154c88bb1..b0bd529da4 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -298,6 +298,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py index bef32017ea..b1df4a6706 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -188,6 +188,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 9715bc6b75..2ca8ee07b4 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -294,6 +294,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. 
diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 697af20899..ddff7a164a 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -131,8 +131,7 @@ def call( shape = ops.shape(hidden_states) batch_size, length = shape[0], shape[1] causal_mask = compute_causal_mask(batch_size, length, length) - attention_mask = ops.cast(attention_mask, "int32") - attention_mask = causal_mask & attention_mask + attention_mask = causal_mask & ops.cast(attention_mask, "bool") x = hidden_states # Intermediate result. diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 87948439a8..297ec203de 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -72,6 +72,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): batch_size, max_length = ops.shape(prompt)[0], ops.shape(prompt)[1] index = ops.cast(index, "int32") @@ -167,6 +168,7 @@ def gather_beams(x): body=body, loop_vars=(prompt, cache, index, log_probs), maximum_iterations=(max_length - index), + model=model, ) all_prompts = unflatten_beams(prompt) diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 8b3d52d9a5..4259167c8c 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -73,6 +73,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): if hidden_states is None: raise ValueError( @@ -209,6 +210,7 @@ def gather_best_token(beams): body=body, loop_vars=(prompt, cache, index, logits, hidden_states), maximum_iterations=(max_length - index), + model=model, ) return prompt diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 2101c9277d..3ecf16ac28 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -92,6 +92,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): max_length = ops.shape(prompt)[-1] # Make sure `max_length` and `index` are the same dtype. 
@@ -133,6 +134,7 @@ def body(prompt, cache, index): body, loop_vars=(prompt, cache, index), maximum_iterations=(max_length - index), + model=model, ) return prompt @@ -147,32 +149,68 @@ def compute_probabilities(self, logits): probs = keras.activations.softmax(logits / self.temperature) return ops.cast(probs, logits_dtype) - def run_loop(self, cond, body, loop_vars=None, maximum_iterations=None): + def run_loop( + self, cond, body, model=None, loop_vars=None, maximum_iterations=None + ): """Run ops.while_loops with a `StatelessScope` if necessary.""" if config.backend() == "jax": + import itertools + + if model: + model_trainable_variables = model.trainable_variables + model_non_trainable_variables = model.non_trainable_variables + else: + model_trainable_variables = [] + model_non_trainable_variables = [] - def stateless_cond(variables, *loop_vars): + def stateless_cond(state, *loop_vars): return cond(*loop_vars) - def stateless_body(variables, *loop_vars): - mapping = zip(self.variables, variables) + def stateless_body(state, *loop_vars): + ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.variables, sampler_variables), + zip(model_trainable_variables, trainable_variables), + zip(model_non_trainable_variables, non_trainable_variables), + ) with keras.StatelessScope(state_mapping=mapping) as scope: loop_vars = body(*loop_vars) - variables = [] + sampler_variables = [] for v in self.variables: new_v = scope.get_current_value(v) - variables.append(new_v if new_v is not None else v) - return variables, *loop_vars + sampler_variables.append(new_v if new_v is not None else v) + state = ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) + return state, *loop_vars variables = [ops.convert_to_tensor(v) for v in self.variables] - variables, *loop_vars = ops.while_loop( + trainable_variables = [ + ops.convert_to_tensor(v) for v in model_trainable_variables + ] + non_trainable_variables = [ + ops.convert_to_tensor(v) for v in model_non_trainable_variables + ] + state = ( + variables, + trainable_variables, + non_trainable_variables, + ) + state, *loop_vars = ops.while_loop( cond=stateless_cond, body=stateless_body, - loop_vars=(variables, *loop_vars), + loop_vars=(state, *loop_vars), maximum_iterations=maximum_iterations, ) - [ref_v.assign(v) for ref_v, v in zip(self.variables, variables)] + for ref_v, v in zip(self.variables, state[0]): + ref_v.assign(v) else: loop_vars = ops.while_loop( cond=cond, diff --git a/keras_nlp/tests/test_data/gemma_test_vocab.spm b/keras_nlp/tests/test_data/gemma_test_vocab.spm new file mode 100644 index 0000000000..a049c032c2 Binary files /dev/null and b/keras_nlp/tests/test_data/gemma_test_vocab.spm differ diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index ae655aceb6..64e169939c 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -253,6 +253,8 @@ def tokenize(self, inputs): def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, _ = convert_to_ragged_batch(inputs) + # tf-text sentencepiece does not handle int64. 
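+        # Cast token ids down to int32 before detokenizing.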
+ inputs = tf.cast(inputs, "int32") outputs = self._sentence_piece.detokenize(inputs) if unbatched: outputs = tf.squeeze(outputs, 0) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 6bb2748fd9..01c11a3db1 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -16,6 +16,7 @@ import json import os +from keras_nlp.backend import config as backend_config from keras_nlp.backend import keras try: @@ -180,6 +181,13 @@ def load_from_preset( # Optionally load weights. load_weights = load_weights and config["weights"] if load_weights: + # For jax, delete all previous allocated memory to avoid temporarily + # duplicating variable allocations. torch and tensorflow have stateful + # variable types and do not need this fix. + if backend_config.backend() == "jax": + for weight in layer.weights: + if getattr(weight, "_value", None) is not None: + weight._value.delete() weights_path = get_file(preset, config["weights"]) layer.load_weights(weights_path) diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py new file mode 100644 index 0000000000..31e3f3c69b --- /dev/null +++ b/tools/gemma/export_gemma_to_hf.py @@ -0,0 +1,328 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +import torch +import transformers +from absl import app +from absl import flags + +import keras_nlp + +os.environ["KERAS_BACKEND"] = "torch" + +""" +Sample usage: + +For converting a keras model to HuggingFace format using a custom or fine-tuned +checkpoint from Keras, make sure to pass the path for the Keras weights file +(ending in `.weights.h5`), the model size (`2b` or `7b`), and the tokenizer +vocabulary file (`.spm`, `.model`, or equivalent) to +`--weights_file`, `--size`, and `--vocab_path`, respectively. + +Optionally, you can specify the output directory +for the converted model at `--output_dir`. (defaults to `gg_hf`) +``` +python tools/gemma/export_gemma_to_hf.py \ + --weights_file fine_tuned_imdb.weights.h5 \ + --size 2b \ + --vocab_path gemma_lm_tokenizer/vocabulary.spm \ + --output_dir fine_tuned_gg_hf +``` + +For converting a Keras model to HuggingFace format from a preset, +simply pass the Keras preset name to `--preset` and its model size +(`2b` or `7b`) to `--size`. 
+``` +python tools/gemma/export_gemma_to_hf.py \ + --preset gemma_2b_en \ + --size 2b \ + --output_dir keras_hf_model/ +``` +""" + + +PRESET_MAP = { + "gemma_2b_en": "gg-hf/gemma-2b", + "gemma_instruct_2b_en": "gg-hf/gemma-2b", + "gemma_7b_en": "gg-hf/gemma-7b", + "gemma_instruct_7b_en": "gg-hf/gemma-7b", +} + +SIZE_MAP = { + "2b": ("gg-hf/gemma-2b", "gemma_2b_en"), + "7b": ("gg-hf/gemma-7b", "gemma_7b_en"), +} + +gemma_2b_config = transformers.GemmaConfig( + num_hidden_layers=18, + num_attention_heads=8, + num_key_value_heads=1, + hidden_size=2048, + intermediate_size=16384, +) + +gemma_7b_config = transformers.GemmaConfig() + +CONFIG_MAPPING = {"2b": gemma_2b_config, "7b": gemma_7b_config} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "hf_token", + None, + "Your HuggingFace token. Needed for access to the HuggingFace Gemma" + "implementation since the repository is private, for now.", +) +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}' + " Alternatively, a Keras weights file (`.weights.h5`) can be passed" + " to --weights_file flag.", +) +flags.DEFINE_string( + "weights_file", + None, + "A Keras weights file (`.weights.h5`)." + " Alternatively, a model preset can be passed to --preset flag.", +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `weights_file` is passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "output_dir", + "gg_hf", + "An output directory for the converted HuggingFace model and tokenizer.", +) +flags.DEFINE_string( + "vocab_path", + None, + "A path containing the vocabulary (must be a `.spm` file or equivalent). " + "If not passed, the vocabulary of the preset will be used.", +) + + +def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path): + if preset is not None: + hf_id = PRESET_MAP[preset] + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + hf_id, keras_preset = SIZE_MAP[size.lower()] + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + keras_preset + ) + keras_nlp_model.load_weights(weights_file) + + print(f"\n-> Loading HuggingFace Gemma `{size.upper()}` model...") + hf_model = transformers.GemmaForCausalLM(CONFIG_MAPPING[size.lower()]) + + print("\n✅ Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to HuggingFace Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + hf_vocab_size = hf_model.model.embed_tokens.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if hf_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - hf_vocab_size + update_state_dict( + hf_model.model.embed_tokens, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + hf_model.model.embed_tokens, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + + # Pre-attention norm + update_state_dict( + hf_model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + query_target_shape = hf_model.model.layers[ + i + ].self_attn.q_proj.weight.shape + query_tensor = decoder_block.attention.query_dense.weights[0].value + 
query_tensor = query_tensor.transpose(1, 2).reshape(query_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.q_proj, "weight", query_tensor + ) + + key_target_shape = hf_model.model.layers[ + i + ].self_attn.k_proj.weight.shape + key_tensor = decoder_block.attention.key_dense.weights[0].value + key_tensor = key_tensor.transpose(1, 2).reshape(key_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.k_proj, "weight", key_tensor + ) + + value_target_shape = hf_model.model.layers[ + i + ].self_attn.v_proj.weight.shape + value_tensor = decoder_block.attention.value_dense.weights[0].value + value_tensor = value_tensor.transpose(1, 2).reshape(value_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.v_proj, "weight", value_tensor + ) + + out_target_shape = hf_model.model.layers[ + i + ].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + update_state_dict( + hf_model.model.layers[i].self_attn.o_proj, "weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + hf_model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + hf_model.model.layers[i].mlp.gate_proj, + "weight", + decoder_block.gating_ffw.weights[0].value.transpose(0, 1), + ) + update_state_dict( + hf_model.model.layers[i].mlp.up_proj, + "weight", + decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1), + ) + update_state_dict( + hf_model.model.layers[i].mlp.down_proj, + "weight", + decoder_block.ffw_linear.weights[0].value.transpose(0, 1), + ) + + # Final norm + update_state_dict( + hf_model.model.norm, + "weight", + keras_nlp_model.backbone.layers[-1].weights[0].value, + ) + + print("\n✅ Weights converted successfully.") + print(f"\n-> Saving HuggingFace model to `{output_dir}`...") + + # Save model to HF Transformers format + os.makedirs(output_dir, exist_ok=True) + hf_model.save_pretrained(output_dir) + + print(f"\n✅ Saving complete. Model saved at `{output_dir}`.") + + # Tokenizer + + if not vocab_path: + tokenizer_preset = preset or SIZE_MAP[size.lower()] + print( + "\n-> Loading KerasNLP Gemma tokenizer with " + f"preset `{tokenizer_preset}`..." + ) + keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset( + tokenizer_preset + ) + # Save tokenizer state + keras_nlp_tokenizer.save_assets(output_dir) + vocab_path = os.path.join(output_dir, "vocabulary.spm") + print("\n✅ Tokenizer loading complete.") + + hf_tokenizer = transformers.GemmaTokenizer(vocab_path) + + print(f"\n-> Saving HuggingFace Gemma tokenizer to `{output_dir}`...") + # Save tokenizer to HF Transformers format + hf_tokenizer.save_pretrained(output_dir) + + print(f"\n✅ Saving complete. 
Tokenizer saved at `{output_dir}`.") + + +def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None: + """Updates the state dict for a weight given a tensor.""" + assert ( + tensor.shape == layer.state_dict()[weight_name].shape + ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + layer.state_dict()[weight_name].copy_(tensor) + + +def flag_error_handler(): + if not FLAGS.preset and not FLAGS.weights_file: + raise ValueError( + "Please pass either a valid Keras preset to `--preset`" + " or supply a Keras weights file (`.weights.h5`) and model size" + " (`2b` or `7b`) to `--weights_file` and `--size`, respectively." + ) + if FLAGS.weights_file: + if FLAGS.preset: + raise ValueError( + "Both `--preset` and `--weights_file` flags cannot be supplied " + "at the same time. Either supply a valid Keras preset to " + "`--preset`or supply a Keras `.weights.h5` file and " + "model size (`2b` or `7b`) to `--weights_file` and `--size`, " + "respectively." + ) + if not str(FLAGS.weights_file).endswith(".weights.h5"): + raise ValueError( + "Please pass a valid Keras weights file ending in `.weights.h5`." + ) + if not FLAGS.size: + raise ValueError( + "The `size` flag must be passed if a weights file is passed. " + "Please pass the appropriate size (`2b` or `7b`) for your " + "model to the `--size` flag." + ) + if FLAGS.size.lower() not in ["2b", "7b"]: + raise ValueError( + "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " + "for your model to the `--size` flag." + ) + + +def main(_): + flag_error_handler() + convert_checkpoints( + FLAGS.preset, + FLAGS.weights_file, + FLAGS.size, + FLAGS.output_dir, + FLAGS.vocab_path, + ) + + +if __name__ == "__main__": + flags.mark_flag_as_required("size") + app.run(main) diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py new file mode 100644 index 0000000000..005eac272d --- /dev/null +++ b/tools/gemma/export_gemma_to_torch_xla.py @@ -0,0 +1,322 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os + +import gemma +import torch +import torch_xla.core.xla_model as xm +from absl import app +from absl import flags +from gemma import model_xla as gemma_model + +import keras_nlp + +os.environ["KERAS_BACKEND"] = "torch" + +""" +Sample usage: + +For converting a Keras model to PyTorch format using a custom or fine-tuned +checkpoint from Keras, make sure to pass the path for the Keras weights file +(ending in `.weights.h5`) and the model size (`2b` or `7b`) to `--weights_file` +and `--size`, respectively. + +Optionally, you can specify the output path for the converted model at +`--output_file`. (This defaults to `gemma.ckpt`) +``` +python tools/gemma/export_gemma_to_torch_xla.py \ + --weights_file fine_tuned_imdb.weights.h5 \ + --size 2b \ + --output_file fine_tuned_imdb.ckpt +``` + +For converting a Keras model to PyTorch format from a preset, +simply pass the Keras preset name to `--preset`. 
+``` +python tools/gemma/export_gemma_to_torch_xla.py \ + --preset gemma_2b_en \ + --output_file path/to/keras_torch_model.ckpt +``` +""" + + +PRESET_MAP = { + "gemma_2b_en": gemma.config.get_config_for_2b(), + "gemma_instruct_2b_en": gemma.config.get_config_for_2b(), + "gemma_7b_en": gemma.config.get_config_for_7b(), + "gemma_instruct_7b_en": gemma.config.get_config_for_7b(), +} + +SIZE_MAP = { + "2b": (gemma.config.get_config_for_2b(), "gemma_2b_en"), + "7b": (gemma.config.get_config_for_7b(), "gemma_7b_en"), +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}' + " Alternatively, a Keras weights file (`.weights.h5`) can be passed" + " to --weights_file flag.", +) +flags.DEFINE_string( + "weights_file", + None, + "A Keras weights file (`.weights.h5`)." + " Alternatively, a model preset can be passed to --preset flag.", +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `weights_file` is passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "output_file", + "gemma.ckpt", + "An output file for the converted PyTorch checkpoint. Default: `gemma.ckpt`", +) +flags.DEFINE_string( + "vocab_dir", + "gemma_tokenizer", + "A directory in which the vocabulary for the tokenizer will be stored.", +) +flags.DEFINE_string( + "dtype", + "float32", + "Set the precision of the converted checkpoint. Must be a valid PyTorch dtype.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def _reconcile_attention_dims(qkv, target_shape): + return torch.cat(qkv).reshape(tuple(target_shape)) + + +def convert_checkpoints(preset, weights_file, size, output_file, vocab_dir): + device = xm.xla_device() + + if preset is not None: + print( + f"\n-> Loading PyTorch Gemma model config for preset `{preset}`..." 
+ ) + model = gemma_model.GemmaForCausalLM( + PRESET_MAP[preset], world_size=1, rank=0, device=device + ) + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + print(f"\n-> Loading PyTorch Gemma model config for `{size}` model...") + config, size_preset = SIZE_MAP[size.lower()] + model = gemma_model.GemmaForCausalLM( + config, world_size=1, rank=0, device=device + ) + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + size_preset + ) + keras_nlp_model.load_weights(weights_file) + + print("\n✅ Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to PyTorch Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + torch_vocab_size = model.embedder.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if torch_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - torch_vocab_size + update_state_dict( + model.embedder, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + model.embedder, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + # Pre-attention norm + update_state_dict( + model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + qkv = ( + decoder_block.attention.query_dense.weights[0].value.transpose( + 1, 2 + ), + decoder_block.attention.key_dense.weights[0].value.transpose(1, 2), + decoder_block.attention.value_dense.weights[0].value.transpose( + 1, 2 + ), + ) + qkv_target_shape = model.model.layers[i].self_attn.qkv_proj.weight.shape + combined_tensor = _reconcile_attention_dims(qkv, qkv_target_shape) + + update_state_dict( + model.model.layers[i].self_attn.qkv_proj, "weight", combined_tensor + ) + + out_target_shape = model.model.layers[i].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + update_state_dict( + model.model.layers[i].self_attn.o_proj, "weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + model.model.layers[i].mlp.gate_proj, + "weight", + decoder_block.gating_ffw.weights[0].value.transpose(0, 1), + ) + update_state_dict( + model.model.layers[i].mlp.up_proj, + "weight", + decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1), + ) + update_state_dict( + model.model.layers[i].mlp.down_proj, + "weight", + decoder_block.ffw_linear.weights[0].value.transpose(0, 1), + ) + + # Final norm + update_state_dict( + model.model.norm, + "weight", + keras_nlp_model.backbone.layers[-1].weights[0].value, + ) + + print("\n✅ Weights converted successfully.") + print(f"\n-> Saving PyTorch model checkpoint to `{output_file}`...") + + # Save model checkpoint + torch.save({"model_state_dict": model.state_dict()}, output_file) + + print( + f"\n✅ Saving complete. Model checkpoint available at `{output_file}`." 
+ ) + + if preset is not None: + # Tokenizer + print( + f"\n-> Loading KerasNLP Gemma tokenizer with preset `{preset}`..." + ) + keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset( + preset + ) + print("\n✅ Model loading complete.") + print(f"\n-> Saving tokenizer state to directory `{vocab_dir}`...") + + # Save tokenizer state + os.makedirs(vocab_dir, exist_ok=True) + keras_nlp_tokenizer.save_assets(vocab_dir) + + print( + "\n✅ Saving complete. Tokenizer state " + f"available at `{vocab_dir}/vocabulary.spm`." + ) + + +def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None: + """Updates the state dict for a weight given a tensor.""" + assert ( + tensor.shape == layer.state_dict()[weight_name].shape + ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + layer.state_dict()[weight_name].copy_(tensor) + + +def flag_error_handler(): + if not FLAGS.preset and not FLAGS.weights_file: + raise ValueError( + "Please pass either a valid Keras preset to `--preset`" + " or supply a Keras weights file (`.weights.h5`) and model size" + " (`2b` or `7b`) to `--weights_file` and `--size`, respectively." + ) + if FLAGS.weights_file: + if FLAGS.preset: + raise ValueError( + "Both `--preset` and `--weights_file` flags cannot be supplied " + "at the same time. Either supply a valid Keras preset to " + "`--preset`or supply a Keras `.weights.h5` file and " + "model size (`2b` or `7b`) to `--weights_file` and `--size`, " + "respectively." + ) + if not str(FLAGS.weights_file).endswith(".weights.h5"): + raise ValueError( + "Please pass a valid Keras weights file ending in `.weights.h5`." + ) + if not FLAGS.size: + raise ValueError( + "The `size` flag must be passed if a weights file is passed. " + "Please pass the appropriate size (`2b` or `7b`) for your " + "model to the `--size` flag." + ) + if FLAGS.size.lower() not in ["2b", "7b"]: + raise ValueError( + "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " + "for your model to the `--size` flag." + ) + if FLAGS.dtype: + dtype = getattr(torch, FLAGS.dtype) + if not isinstance(dtype, torch.dtype): + raise ValueError( + "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. " + "`float32', 'float16`, etc.) to the `--dtype` flag." + ) + + +def main(_): + flag_error_handler() + with _set_default_tensor_type(getattr(torch, FLAGS.dtype)): + convert_checkpoints( + FLAGS.preset, + FLAGS.weights_file, + FLAGS.size, + FLAGS.output_file, + FLAGS.vocab_dir, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py new file mode 100644 index 0000000000..9fa50cbd2b --- /dev/null +++ b/tools/gemma/run_gemma_xla.py @@ -0,0 +1,287 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
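+"""Verification script for Gemma checkpoints exported to PyTorch/XLA."""
+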
+import contextlib
+import os
+import random
+import sys
+from typing import List
+
+import gemma.xla_model_parallel as xla_model_parallel
+import numpy as np
+import torch
+import torch.multiprocessing
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+from absl import app
+from absl import flags
+from gemma.config import GemmaConfig
+from gemma.config import get_config_for_2b
+from gemma.config import get_config_for_7b
+from gemma.model_xla import GemmaForCausalLM
+from gemma.tokenizer import Tokenizer
+
+PAD_TOKEN_ID = -1
+
+FILE_PATH = "gemma.ckpt"
+TOKENIZER_DIR = "gemma_tokenizer"
+
+PRESET_MAP = {
+    "gemma_2b_en": get_config_for_2b(),
+    "gemma_instruct_2b_en": get_config_for_2b(),
+    "gemma_7b_en": get_config_for_7b(),
+    "gemma_instruct_7b_en": get_config_for_7b(),
+}
+
+SIZE_MAP = {
+    "2b": get_config_for_2b(),
+    "7b": get_config_for_7b(),
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+)
+flags.DEFINE_string(
+    "size",
+    None,
+    "Size of model. Must be passed if `preset` is not passed. "
+    "This should be either `2b` or `7b`.",
+)
+flags.DEFINE_string(
+    "checkpoint_file",
+    "gemma.ckpt",
+    "A PyTorch checkpoint file containing the converted weights.",
+)
+flags.DEFINE_string(
+    "vocab_file",
+    "gemma_tokenizer/vocabulary.spm",
+    "The file containing the vocabulary for the tokenizer.",
+)
+flags.DEFINE_string(
+    "prompt",
+    "The capital of France is",
+    "A test prompt for verifying functionality of the PyTorch Gemma model.",
+)
+
+# This is a modified version of the `run_xla.py` script in the Hex-LLM Gemma
+# repo to ensure proper functionality after porting checkpoints from Keras.
+
+
+@contextlib.contextmanager
+def _set_default_tensor_type(dtype: torch.dtype):
+    """Sets the default torch dtype to the given dtype."""
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(torch.float)
+
+
+def generate(
+    i: int,
+    model_config: GemmaConfig,
+    checkpoint_file: str,
+    vocab_file: str,
+    prompts: List[str],
+    output_lens: List[int],
+    temperatures: List[float],
+    top_ps: List[float],
+    top_ks: List[int],
+):
+    # Set seed from config
+    seed = model_config.seed
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    device = xm.xla_device()
+    xm.set_rng_state(seed, device)
+
+    rank = xla_model_parallel.get_model_parallel_rank()
+    world_size = xla_model_parallel.get_model_parallel_world_size()
+    if rank > 0:
+        sys.stdout = open(os.devnull, "w")
+
+    # Load model with ported weights and place on device
+    with _set_default_tensor_type(model_config.get_dtype()):
+        model = GemmaForCausalLM(model_config, world_size, rank, device)
+        model.load_weights(checkpoint_file)
+        model = model.to(device).eval()
+
+    # Create tokenizer with saved Keras tokenizer state
+    tokenizer = Tokenizer(vocab_file)
+
+    prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
+    min_prompt_len = min(len(p) for p in prompt_tokens)
+
+    batch_size = len(prompts)
+    assert batch_size == len(temperatures)
+    assert batch_size == len(top_ps)
+    assert batch_size == len(top_ks)
+    max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
+    assert max_seq_len <= model_config.max_position_embeddings
+    if model_config.num_key_value_heads < world_size:
+        assert world_size % model_config.num_key_value_heads == 0
+        n_local_heads = 1
+    else:
+        assert model_config.num_key_value_heads % world_size == 0
+        n_local_heads = model_config.num_key_value_heads // world_size
+
+    # build KV caches
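+    # One (key, value) cache pair is pre-allocated per decoder layer, each
+    # shaped (batch_size, max_seq_len, n_local_heads, head_dim); the model
+    # updates these caches in place as tokens are processed.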
+    kv_caches = []
+    for _ in range(model_config.num_hidden_layers):
+        k_cache = torch.zeros(
+            size=(
+                batch_size,
+                max_seq_len,
+                n_local_heads,
+                model_config.head_dim,
+            ),
+            dtype=model_config.get_dtype(),
+            device=device,
+        )
+        v_cache = torch.zeros(
+            size=(
+                batch_size,
+                max_seq_len,
+                n_local_heads,
+                model_config.head_dim,
+            ),
+            dtype=model_config.get_dtype(),
+            device=device,
+        )
+        kv_caches.append((k_cache, v_cache))
+
+    # prepare inputs
+    token_ids_tensor = torch.full(
+        (batch_size, max_seq_len), PAD_TOKEN_ID, dtype=torch.int64
+    )
+    input_token_ids_tensor = torch.full(
+        (batch_size, min_prompt_len), PAD_TOKEN_ID, dtype=torch.int64
+    )
+    for i, p in enumerate(prompt_tokens):
+        token_ids_tensor[i, : len(p)] = torch.tensor(p)
+        input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
+            p[:min_prompt_len]
+        )
+    token_ids_tensor = token_ids_tensor.to(device)
+    prompt_mask_tensor = token_ids_tensor != PAD_TOKEN_ID
+    input_token_ids_tensor = input_token_ids_tensor.to(device)
+    input_positions_tensor = torch.arange(
+        0, min_prompt_len, dtype=torch.int64
+    ).to(device)
+    mask_tensor = torch.full(
+        (1, 1, max_seq_len, max_seq_len), -2.3819763e38
+    ).to(torch.float)
+    mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
+    curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
+    output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
+    temperatures_tensor = torch.FloatTensor(temperatures).to(device)
+    top_ps_tensor = torch.FloatTensor(top_ps).to(device)
+    top_ks_tensor = torch.LongTensor(top_ks).to(device)
+    output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
+    xm.mark_step()
+
+    # Prefill up to min_prompt_len tokens, then treat other prefill as decode
+    # and ignore output.
+    for i in range(max_seq_len - min_prompt_len):
+        next_token_ids = model(
+            input_token_ids=input_token_ids_tensor,
+            input_positions=input_positions_tensor,
+            kv_write_indices=None,
+            kv_caches=kv_caches,
+            mask=curr_mask_tensor,
+            output_positions=output_positions_tensor,
+            temperatures=temperatures_tensor,
+            top_ps=top_ps_tensor,
+            top_ks=top_ks_tensor,
+        )
+        curr_prompt_mask = prompt_mask_tensor.index_select(
+            1, output_index
+        ).squeeze(dim=1)
+        curr_token_ids = token_ids_tensor.index_select(1, output_index).squeeze(
+            dim=1
+        )
+        output_token_ids = torch.where(
+            curr_prompt_mask, curr_token_ids, next_token_ids
+        ).unsqueeze(dim=1)
+        token_ids_tensor.index_copy_(1, output_index, output_token_ids)
+
+        input_token_ids_tensor = output_token_ids
+        input_positions_tensor = output_index
+        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
+        output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
+        output_index = output_index + 1
+        xm.mark_step()
+
+    # Detokenization.
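+    # For each prompt, keep only the newly generated tokens, truncate at the
+    # first EOS token if one appears, and decode the rest back to text.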
+    token_ids = token_ids_tensor.tolist()
+    results = []
+    for i, tokens in enumerate(token_ids):
+        trimmed_output = tokens[
+            len(prompt_tokens[i]) : len(prompt_tokens[i]) + output_lens[i]
+        ]
+        if tokenizer.eos_id in trimmed_output:
+            eos_index = trimmed_output.index(tokenizer.eos_id)
+            trimmed_output = trimmed_output[:eos_index]
+        results.append(tokenizer.decode(trimmed_output))
+
+    for prompt, result in zip(prompts, results):
+        print("======================================")
+        print(f"PROMPT: {prompt}")
+        print(f"RESULT: {result}")
+        print("======================================")
+
+
+def flag_error_handler():
+    if not FLAGS.preset and not FLAGS.size:
+        raise ValueError(
+            "Please pass either a valid Keras preset to `--preset`"
+            " or supply a model size (`2b` or `7b`) to `--size`."
+        )
+    if FLAGS.size and FLAGS.size.lower() not in ["2b", "7b"]:
+        raise ValueError(
+            "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
+            "for your model to the `--size` flag."
+        )
+
+
+def main(_):
+    flag_error_handler()
+    if FLAGS.preset:
+        model_config = PRESET_MAP[FLAGS.preset]
+    else:
+        model_config = SIZE_MAP[FLAGS.size.lower()]
+    prompts = [
+        FLAGS.prompt,
+    ]
+    n = len(prompts)
+    output_lengths = [10] * n
+    temperatures = [0.95] * n
+    top_ps = [1.0] * n
+    top_ks = [100] * n
+    xmp.spawn(
+        generate,
+        args=(
+            model_config,
+            FLAGS.checkpoint_file,
+            FLAGS.vocab_file,
+            prompts,
+            output_lengths,
+            temperatures,
+            top_ps,
+            top_ks,
+        ),
+    )
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/tools/sentencepiece_testing/create_gemma_test_proto.py b/tools/sentencepiece_testing/create_gemma_test_proto.py
new file mode 100644
index 0000000000..c3ce418a4b
--- /dev/null
+++ b/tools/sentencepiece_testing/create_gemma_test_proto.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tools.sentencepiece_testing.utils import train_sentencepiece
+
+
+def main():
+    train_sentencepiece(
+        ["the quick brown fox", "the earth is round"],
+        "gemma_test_vocab.spm",
+        vocab_size=11,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+
+
+if __name__ == "__main__":
+    main()