diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 5795c85532..d432e61d5d 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -152,6 +152,13 @@ ) from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM +from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) +from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier diff --git a/keras_nlp/src/models/__init__.py b/keras_nlp/src/models/__init__.py index 9e122f35ad..bd134f4f4d 100644 --- a/keras_nlp/src/models/__init__.py +++ b/keras_nlp/src/models/__init__.py @@ -143,6 +143,13 @@ ) from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM +from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) +from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier diff --git a/keras_nlp/src/models/phi3/__init__.py b/keras_nlp/src/models/phi3/__init__.py new file mode 100644 index 0000000000..2fcd425fdc --- /dev/null +++ b/keras_nlp/src/models/phi3/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.models.phi3.phi3_presets import backbone_presets +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_nlp.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, (Phi3Backbone, Phi3Tokenizer)) diff --git a/keras_nlp/src/models/phi3/phi3_attention.py b/keras_nlp/src/models/phi3/phi3_attention.py new file mode 100644 index 0000000000..072f233018 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_attention.py @@ -0,0 +1,259 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.backend import keras +from keras_nlp.src.backend import ops +from keras_nlp.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_nlp.src.models.phi3.phi3_rotary_embedding import ( + Phi3SuScaledRotaryEmbedding, +) +from keras_nlp.src.utils.keras_utils import clone_initializer + + +class Phi3Attention(keras.layers.Layer): + """A cached grouped query attention layer.""" + + def __init__( + self, + num_query_heads, + num_key_value_heads, + kernel_initializer="glorot_uniform", + dropout=0, + max_sequence_length=4096, + pretraining_sequence_length=4096, + rope_max_wavelength=10000, + rope_scaling_type=None, + rope_scaling_short_factor=None, + rope_scaling_long_factor=None, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.dropout = dropout + + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_type = rope_scaling_type + self.rope_scaling_short_factor = rope_scaling_short_factor + self.rope_scaling_long_factor = rope_scaling_long_factor + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + hidden_dim = inputs_shape[-1] + head_dim = hidden_dim // self.num_query_heads + self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype)) + + self.query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(inputs_shape) + + self.softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self.output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build((None, None, self.num_query_heads, head_dim)) + + if self.rope_scaling_type is None: + self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + ) + elif self.rope_scaling_type == "su": + if len(self.rope_scaling_short_factor) != head_dim // 2: + raise ValueError( + "`rope_scaling_short_factor` must be of length " + "`hidden_dim//num_query_heads//2`. " + "`len(rope_scaling_short_factor)` is " + f"{len(self.rope_scaling_short_factor)} " + f"while it should be {head_dim // 2}." + ) + if len(self.rope_scaling_long_factor) != head_dim // 2: + raise ValueError( + "`rope_scaling_long_factor` must be of length " + "`hidden_dim//num_query_heads//2`. " + "`len(rope_scaling_long_factor)` is " + f"{len(self.rope_scaling_long_factor)} " + f"while it should be {head_dim // 2}." + ) + self.rotary_embedding_layer = Phi3SuScaledRotaryEmbedding( + inverese_freq_short_factor=self.rope_scaling_short_factor, + inverese_freq_long_factor=self.rope_scaling_long_factor, + max_sequence_length=self.max_sequence_length, + pretraining_sequence_length=self.pretraining_sequence_length, + max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + ) + else: + raise ValueError( + '`rope_scaling_type` must be `None` or `"su"`. ' + "If `None` is chosen, `RotaryEmbedding` will be used. " + 'If `"su"` is chosen, `Phi3SuScaledRotaryEmbedding` will be ' + "used." + ) + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + + query = self.query_dense(hidden_states) + key = self.key_dense(hidden_states) + value = self.value_dense(hidden_states) + + # Compute RoPE for queries and keys. + query = self.rotary_embedding_layer(query, start_index=start_index) + key = self.rotary_embedding_layer(key, start_index=start_index) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key) + value = ops.slice_update(value_cache, start, value) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. 
Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, key, value, attention_mask + ) + + attention_output = self.dropout_layer( + attention_output, training=training + ) + + attention_output = self.output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + return self.softmax(attention_scores, attention_mask[:, None, :, :]) + return self.softmax(attention_scores) + + def _compute_attention(self, query, key, value, attention_mask=None): + attention_scores = ops.einsum("bquh,bkuh->buqk", query, key) + attention_scores = attention_scores / self._norm_factor + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + "buqk,bkuh->bquh", attention_scores, value + ) + + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_type": self.rope_scaling_type, + "rope_scaling_short_factor": self.rope_scaling_short_factor, + "rope_scaling_long_factor": self.rope_scaling_long_factor, + } + ) + return config diff --git a/keras_nlp/src/models/phi3/phi3_backbone.py b/keras_nlp/src/models/phi3/phi3_backbone.py new file mode 100644 index 0000000000..65a358213d --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_backbone.py @@ -0,0 +1,223 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.backend import keras +from keras_nlp.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.phi3.phi3_decoder import Phi3Decoder +from keras_nlp.src.models.phi3.phi3_layernorm import Phi3LayerNorm + + +def _phi3_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_nlp_export("keras_nlp.models.Phi3Backbone") +class Phi3Backbone(Backbone): + """Phi-3 core network with hyperparameters. + + This network implements a Transformer-based decoder network, + Phi-3, as described in ["Phi-3 Technical Report"](https://arxiv.org/pdf/2404.14219.pdf). 
+ It includes the embedding lookups and transformer layers. + + The default constructor gives a fully customizable, randomly initialized + Phi-3 model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. + + Args: + vocabulary_size (int): The size of the token vocabulary. + num_layers (int): The number of transformer layers. + hidden_dim (int): The size of the embeddings and the hidden states of + the transformer layers. + intermediate_dim (int): The output dimension of the first Dense layer in + a three-layer feedforward network for each transformer. + num_query_heads (int): The number of query attention heads for each + transformer layer. + num_key_value_heads (int): The number of key and value attention heads + for each transformer layer. + layer_norm_epsilon (float, optional): Epsilon for the RMS layernorm + layers in the transformer decoder. Defaults to `1e-6`. + dropout (float, optional): Dropout probability for the Transformer + decoder. Defaults to `0.0`. + max_sequence_length (int, optional): The maximum sequence length + that this model might ever be used with. Defaults to `4096`. + pretraining_sequence_length (int, optional): The maximum sequence length + that the model was pretrained with. Defaults to `4096`. + rope_max_wavelength (int, optional): The maximum angular wavelength of + the sine/cosine curves, for rotary embeddings. Defaults to `10000`. + rope_scaling_type (str, optional): The type of the rope scaling. Can be + either `None` or `"su"`. `None` means no rope scaling; `"su"` means + SuScaled rope, which should be used when `max_sequence_length` is + larger than `pretraining_sequence_length`. Defaults to `None`. + rope_scaling_short_factor (List[float], optional): List of factors used + to adjust rope frequencies when the `rope_scaling_type` is `"su"`. + List must be of length `hidden_dim//num_query_heads//2`. It is used + when `sequence_length` is smaller than `pretraining_sequence_length`. + Defaults to `None`. + rope_scaling_long_factor (List[float], optional): List of factors used + to adjust rope frequencies when the `rope_scaling_type` is `"su"`. + List must be of length `hidden_dim//num_query_heads//2`. It is used + when `sequence_length` is larger than `pretraining_sequence_length`. + Defaults to `None`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Phi3 decoder. + model = keras_nlp.models.Phi3Backbone.from_preset( + "phi3_mini_4k_instruct_en" + ) + model(input_data) + + # Randomly initialized Phi3 decoder with custom config.
+ model = keras_nlp.models.Phi3Backbone( + vocabulary_size=10, + num_layers=2, + hidden_dim=512, + intermediate_dim=1024, + num_query_heads=32, + num_key_value_heads=8, + layer_norm_epsilon=1e-6, + dtype="float32" + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + hidden_dim, + intermediate_dim, + num_query_heads, + num_key_value_heads, + layer_norm_epsilon=1e-6, + dropout=0.0, + max_sequence_length=4096, + pretraining_sequence_length=4096, + rope_max_wavelength=10000, + rope_scaling_type=None, + rope_scaling_short_factor=None, + rope_scaling_long_factor=None, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=False, + embeddings_initializer=_phi3_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = Phi3Decoder( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_max_wavelength=rope_max_wavelength, + layer_norm_epsilon=layer_norm_epsilon, + activation="silu", + kernel_initializer=_phi3_kernel_initializer(stddev=0.02), + dropout=dropout, + max_sequence_length=max_sequence_length, + pretraining_sequence_length=pretraining_sequence_length, + rope_scaling_type=rope_scaling_type, + rope_scaling_short_factor=rope_scaling_short_factor, + rope_scaling_long_factor=rope_scaling_long_factor, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = Phi3LayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_type = rope_scaling_type + self.rope_scaling_short_factor = rope_scaling_short_factor + self.rope_scaling_long_factor = rope_scaling_long_factor + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_key_value_heads": self.num_key_value_heads, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": 
self.pretraining_sequence_length, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_type": self.rope_scaling_type, + "rope_scaling_short_factor": self.rope_scaling_short_factor, + "rope_scaling_long_factor": self.rope_scaling_long_factor, + } + ) + return config diff --git a/keras_nlp/src/models/phi3/phi3_backbone_test.py b/keras_nlp/src/models/phi3/phi3_backbone_test.py new file mode 100644 index 0000000000..42f2c07e3c --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_backbone_test.py @@ -0,0 +1,105 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.backend import ops +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.tests.test_case import TestCase + + +class Phi3Test(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 8, + "intermediate_dim": 8, + } + self.su_rotary_init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 1, + "hidden_dim": 8, + "intermediate_dim": 12, + "max_sequence_length": 10, + "pretraining_sequence_length": 5, + "rope_scaling_type": "su", + "rope_scaling_short_factor": [1.2, 1.4], + "rope_scaling_long_factor": [0.8, 0.6], + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Phi3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Phi3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_backbone_basics_with_su_rotary(self): + self.run_backbone_test( + cls=Phi3Backbone, + init_kwargs=self.su_rotary_init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model_with_su_rotary(self): + self.run_model_saving_test( + cls=Phi3Backbone, + init_kwargs=self.su_rotary_init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=Phi3Backbone, + preset="phi3_mini_4k_instruct_en", + input_data={ + "token_ids": ops.array([[1, 450, 4996, 1701, 29916, 29889]]), + "padding_mask": ops.ones((1, 6), dtype="int32"), + }, + expected_output_shape=(1, 6, 3072), + # The forward pass from a preset should be stable! + # Reference values computed using PyTorch HF model. 
+ expected_partial_output=ops.array( + [-0.21222, 0.04004, -0.02759, 0.02200] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi3Backbone.presets: + self.run_preset_test( + cls=Phi3Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm.py b/keras_nlp/src/models/phi3/phi3_causal_lm.py new file mode 100644 index 0000000000..6835b009ad --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_causal_lm.py @@ -0,0 +1,217 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.backend import ops +from keras_nlp.src.models.causal_lm import CausalLM +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) +from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.tensor_utils import any_equal + + +@keras_nlp_export("keras_nlp.models.Phi3CausalLM") +class Phi3CausalLM(CausalLM): + """An end-to-end Phi3 model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a Phi-3 model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_nlp.models.Phi3Backbone` instance. + preprocessor: A `keras_nlp.models.Phi3CausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + """ + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.inputs + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + @classproperty + def backbone_cls(cls): + return Phi3Backbone + + @classproperty + def preprocessor_cls(cls): + return Phi3CausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `Phi3CausalLM` with cache. + + `call_with_cache` adds an additional forward pass to the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in the multi-head attention + layers, and avoids recomputing the outputs of seen tokens.
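The cache this method consumes is the one `_build_cache` allocates below; a minimal numpy sketch of that layout, with made-up config values:

```python
import numpy as np

# Hypothetical small config; real values come from the backbone.
batch_size, max_length = 2, 12
num_layers, num_query_heads, num_key_value_heads, hidden_dim = 2, 4, 2, 8
head_dim = hidden_dim // num_query_heads

# One cache for the whole model: axis 1 indexes the decoder layer, and
# axis 2 separates keys (index 0) from values (index 1), matching the
# `cache[:, i, ...]`, `cache[:, 0, ...]`, and `cache[:, 1, ...]` slicing
# used in `call_with_cache` and `Phi3Attention.call`.
cache = np.zeros(
    (batch_size, num_layers, 2, max_length, num_key_value_heads, head_dim),
    dtype="float32",
)

# Keys/values are stored once per key/value head and repeated across the
# query-head groups at compute time (`ops.repeat(..., axis=2)` in the
# attention layer), which is what makes this grouped query attention.
num_key_value_groups = num_query_heads // num_key_value_heads  # -> 2
```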
+ + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + attention_cache=current_cache, + attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of ids of the stop tokens. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user-inputted token ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user-inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True).
+ end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate(self, inputs, max_length=None, stop_token_ids="auto"): + if self.preprocessor and stop_token_ids == "auto": + # Stop at: + # `<|endoftext|>` (end of sequence token). + # `<|end|>` (end of turn token). + stop_token_ids = [self.preprocessor.tokenizer.end_token_id] + end_of_turn_id = self.preprocessor.tokenizer.token_to_id("<|end|>") + if end_of_turn_id != 0: + # If `<|end|>` exists in the vocabulary. + stop_token_ids.append(end_of_turn_id) + + return super().generate(inputs, max_length, stop_token_ids) diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py new file mode 100644 index 0000000000..c63ed508fb --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py @@ -0,0 +1,191 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import tensorflow as tf +except ImportError: + raise ImportError( + "To use `keras_nlp`, please install Tensorflow: `pip install tensorflow`. " + "The TensorFlow package is required for data preprocessing with any backend." + ) +from absl import logging + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.backend import ops +from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor +from keras_nlp.src.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.Phi3CausalLMPreprocessor") +class Phi3CausalLMPreprocessor(Phi3Preprocessor): + """Phi3 Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.Phi3CausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.Phi3CausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_nlp.models.Phi3Tokenizer` instance. + sequence_length: The length of the packed inputs. 
+ add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.Phi3CausalLMPreprocessor.from_preset( + "phi3_mini_4k_instruct_en" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("League of legends") + preprocessor(sentence) + # Same output. + preprocessor("League of legends") + + # Tokenize a batch of sentences. + sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) + preprocessor(sentences) + # Same output. + preprocessor(["Taco tuesday", "Fish taco please!"]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "Avatar 2 is amazing!", + "Well, I am not sure.", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabeled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`Phi3CausalLMPreprocessor` generates `y` and " + "`sample_weight` based on your input data, but your data " + "already contains `y` or `sample_weight`. Your `y` and " + "`sample_weight` will be ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt).
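A hedged usage sketch (the preset name is the one used in the examples above):

```python
import keras_nlp

preprocessor = keras_nlp.models.Phi3CausalLMPreprocessor.from_preset(
    "phi3_mini_4k_instruct_en"
)
# Tokenize and pack a prompt; no end token is appended, so generation
# can continue where the prompt stops.
batch = preprocessor.generate_preprocess(["the fox"], sequence_length=16)
# `batch` is a dict with "token_ids" and "padding_mask", ready for
# `Phi3CausalLM.generate(..., stop_token_ids=None)` in token-id space.
strings = preprocessor.generate_postprocess(batch)  # back to strings
```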
+ """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Convert the inputs to numpy arrays if they aren't a tensor already. + if not isinstance(token_ids, tf.Tensor): + token_ids = ops.convert_to_numpy(token_ids) + # Make sure the numpy array has type `int32` since + # `SentencePieceProcessor.detokenize` only accepts `int32` arrays. + token_ids = token_ids.astype("int32") + if not isinstance(padding_mask, tf.Tensor): + padding_mask = ops.convert_to_numpy(padding_mask) + padding_mask = padding_mask.astype("bool") + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. + padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + padding_mask = padding_mask & ( + token_ids != self.tokenizer.start_token_id + ) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..b88ba11301 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py @@ -0,0 +1,97 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Phi3CausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = Phi3Tokenizer( + # Generated using create_phi3_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "phi3_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 10, + } + # [3, 5, 6, 4, 3, 9, 7, 11] + self.input_data = (["the fox"],) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=Phi3CausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], + }, + [[3, 5, 6, 4, 3, 9, 7, 11, 0, 0]], # Pass through labels. + [ + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0] + ], # Pass through sample_weights. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the fox"] * 4 + + preprocessor = Phi3CausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [[3, 5, 6, 4, 3, 9, 7, 11, 0, 0]] * 4 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 4 + ) + self.assertAllEqual(y, [[5, 6, 4, 3, 9, 7, 11, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the fox" + preprocessor = Phi3CausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 3, 5, 6, 4, 3, 9, 7, 11, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 3, 5, 6, 4, 3, 9, 7, 11, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + } + preprocessor = Phi3CausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi3CausalLMPreprocessor.presets: + self.run_preset_test( + cls=Phi3CausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm_test.py b/keras_nlp/src/models/phi3/phi3_causal_lm_test.py new file mode 100644 index 0000000000..15740a20c1 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_causal_lm_test.py @@ -0,0 +1,131 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
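The expected outputs in the preprocessor tests above follow from the one-token shift that `Phi3CausalLMPreprocessor.call` applies; a minimal numpy sketch of that shift (token values are made up):

```python
import numpy as np

# Pack with one extra position, as `call` does via `sequence_length + 1`.
token_ids = np.array([[1, 3, 5, 6, 4, 3, 9, 7, 11, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])

# Features drop the last position; labels and weights drop the first, so
# each feature position is trained to predict the following token.
x = {"token_ids": token_ids[..., :-1], "padding_mask": padding_mask[..., :-1]}
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
```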
+ +import os +from unittest.mock import patch + +import pytest + +from keras_nlp.src.backend import ops +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM +from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Phi3CausalLMTest(TestCase): + def setUp(self): + self.preprocessor = Phi3CausalLMPreprocessor( + Phi3Tokenizer( + # Generated using create_phi3_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "phi3_test_vocab.spm" + ) + ), + sequence_length=12, + ) + self.vocab_size = self.preprocessor.tokenizer.vocabulary_size() + self.backbone = Phi3Backbone( + vocabulary_size=self.vocab_size, + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=Phi3CausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 12, self.vocab_size), + ) + + def test_generate(self): + causal_lm = Phi3CausalLM(**self.init_kwargs) + # String input. + prompt = "the fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = Phi3CausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = Phi3CausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Phi3CausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi3CausalLM.presets: + self.run_preset_test( + cls=Phi3CausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/phi3/phi3_decoder.py b/keras_nlp/src/models/phi3/phi3_decoder.py new file mode 100644 index 0000000000..134ce7d71b --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_decoder.py @@ -0,0 +1,259 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.backend import keras +from keras_nlp.src.backend import ops +from keras_nlp.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.src.models.phi3.phi3_attention import Phi3Attention +from keras_nlp.src.models.phi3.phi3_layernorm import Phi3LayerNorm +from keras_nlp.src.utils.keras_utils import clone_initializer + + +class Phi3Decoder(keras.layers.Layer): + """A Transformer decoder layer for the Phi-3 backbone.""" + + def __init__( + self, + hidden_dim, + intermediate_dim, + num_query_heads, + num_key_value_heads, + activation="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + dropout=0, + max_sequence_length=4096, + pretraining_sequence_length=4096, + rope_max_wavelength=10000, + rope_scaling_type=None, + rope_scaling_short_factor=None, + rope_scaling_long_factor=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_type = rope_scaling_type + self.rope_scaling_short_factor = rope_scaling_short_factor + self.rope_scaling_long_factor = rope_scaling_long_factor + + self.dropout = dropout + + self.layer_norm_epsilon = layer_norm_epsilon + self.activation = keras.activations.get(activation) + self.kernel_initializer = keras.initializers.get(kernel_initializer) + + def build(self, decoder_sequence_shape): + + # Pre-attention layernorm. + self.pre_attention_layernorm = Phi3LayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_layernorm", + ) + self.pre_attention_layernorm.build(decoder_sequence_shape) + + # Self-attention layer.
+ self.attention = Phi3Attention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + max_sequence_length=self.max_sequence_length, + pretraining_sequence_length=self.pretraining_sequence_length, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_type=self.rope_scaling_type, + rope_scaling_short_factor=self.rope_scaling_short_factor, + rope_scaling_long_factor=self.rope_scaling_long_factor, + dtype=self.dtype_policy, + name="attention", + ) + self.attention.build(decoder_sequence_shape) + + # Post-attention layernorm. + self.post_attention_layernorm = Phi3LayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self.post_attention_layernorm.build(decoder_sequence_shape) + + # feedforward layers. + self.feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self.feedforward_intermediate_dense.build(decoder_sequence_shape) + + self.feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self.feedforward_gate_dense.build(decoder_sequence_shape) + + self.feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self.feedforward_output_dense.build( + self.feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + # Dropout + self.attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="attention_dropout", + ) + self.feedforward_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="feedforward_dropout", + ) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + attention_cache=None, + attention_cache_update_index=None, + ): + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + attention_cache=attention_cache, + attention_cache_update_index=attention_cache_update_index, + ) + residual = decoder_sequence + x = self.pre_attention_layernorm(decoder_sequence) + x = self.attention( + hidden_states=x, + attention_mask=self_attention_mask, + cache=attention_cache, + cache_update_index=attention_cache_update_index, + ) + if attention_cache is not None: + x, attention_cache = x + x = self.attention_dropout(x) + x = x + residual + + residual = x + x = self.post_attention_layernorm(x) + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. 
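For reference, a minimal numpy sketch of the gated (SwiGLU-style) feedforward computed just below, with plain matrices standing in for the three Dense layers:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))  # x * sigmoid(x)

hidden_dim, intermediate_dim = 8, 16
x = np.random.rand(1, 4, hidden_dim).astype("float32")
w_gate = np.random.rand(hidden_dim, intermediate_dim).astype("float32")
w_up = np.random.rand(hidden_dim, intermediate_dim).astype("float32")
w_down = np.random.rand(intermediate_dim, hidden_dim).astype("float32")

# Mirrors: activation(gate_dense(x)) * intermediate_dense(x) -> output_dense.
gate = silu(x @ w_gate)
out = (gate * (x @ w_up)) @ w_down  # back to shape (1, 4, hidden_dim)
```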
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = self.feedforward_gate_dense(x) + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + x = self.feedforward_intermediate_dense(x) + x = self.feedforward_output_dense(ops.multiply(x, gate_output)) + x = self.feedforward_dropout(x) + decoder_output = x + residual + + if attention_cache is not None: + return decoder_output, attention_cache + return decoder_output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + attention_cache, + attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if attention_cache is not None: + input_length = ops.shape(attention_cache)[2] + + cache_update_index = ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_type": self.rope_scaling_type, + "rope_scaling_short_factor": self.rope_scaling_short_factor, + "rope_scaling_long_factor": self.rope_scaling_long_factor, + } + ) + return config diff --git a/keras_nlp/src/models/phi3/phi3_layernorm.py b/keras_nlp/src/models/phi3/phi3_layernorm.py new file mode 100644 index 0000000000..ac2d3addd4 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_layernorm.py @@ -0,0 +1,48 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.backend import keras +from keras_nlp.src.backend import ops + + +# TODO: Deprecate this in favor of +# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is +# removed. +class Phi3LayerNorm(keras.layers.Layer): + """A normalization layer for Phi-3 that implements RMS normalization.""" + + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x, self.compute_dtype) * self.scale + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config diff --git a/keras_nlp/src/models/phi3/phi3_preprocessor.py b/keras_nlp/src/models/phi3/phi3_preprocessor.py new file mode 100644 index 0000000000..3ef58ec886 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_preprocessor.py @@ -0,0 +1,189 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.Phi3Preprocessor") +class Phi3Preprocessor(Preprocessor): + """A Phi3 preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do three things: + + 1. Tokenize any number of input segments using the `tokenizer`. + 2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker` + with the appropriate tokens. + 3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"` + that can be passed directly to `keras_nlp.models.Phi3Backbone`. + + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + Args: + tokenizer: A `keras_nlp.models.Phi3Tokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A tensor of single string sequences, or a tuple of multiple + tensor sequences to be packed together.
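The `Phi3LayerNorm` above is plain RMS normalization; a small numeric sketch of its `call`:

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype="float32")
epsilon = 1e-6
scale = np.ones(4, dtype="float32")  # the layer's learned "scale" weight

# Mean of squares, then multiply by the reciprocal square root.
var = np.mean(np.power(x, 2), axis=-1, keepdims=True)  # -> [[7.5]]
out = (x / np.sqrt(var + epsilon)) * scale
# Note: no mean subtraction and no bias, unlike standard LayerNorm.
```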
Inputs may be batched or + unbatched. For single sequences, raw python inputs will be converted + to tensors. For multiple sequences, pass tensors directly. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. Will be passed through unaltered. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + + Directly calling the from_preset(). + ```python + preprocessor = keras_nlp.models.Phi3Preprocessor.from_preset( + "" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize and a batch of single sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Preprocess a batch of sentence pairs. + # When handling multiple sequences, always convert to tensors first! + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + preprocessor((first, second)) + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = keras_nlp.models.Phi3Preprocessor.from_preset( + "phi3_mini_4k_instruct_en" + ) + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + label = tf.constant([1, 1]) + + # Map labeled single sentences. + ds = tf.data.Dataset.from_tensor_slices((first, label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled single sentences. + ds = tf.data.Dataset.from_tensor_slices(first) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map labeled sentence pairs. + ds = tf.data.Dataset.from_tensor_slices(((first, second), label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled sentence pairs. + ds = tf.data.Dataset.from_tensor_slices((first, second)) + + # Watch out for tf.data's default unpacking of tuples here! + # Best to invoke the `preprocessor` directly in this case. + ds = ds.map( + lambda first, second: preprocessor(x=(first, second)), + num_parallel_calls=tf.data.AUTOTUNE, + ) + ``` + """ + + tokenizer_cls = Phi3Tokenizer + + def __init__( + self, + tokenizer, + sequence_length=4096, + add_start_token=True, + add_end_token=False, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.add_start_token = add_start_token + self.add_end_token = add_end_token + self.sequence_length = sequence_length + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "Phi3 requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using Phi3" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." 
+ ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_nlp/src/models/phi3/phi3_preprocessor_test.py b/keras_nlp/src/models/phi3/phi3_preprocessor_test.py new file mode 100644 index 0000000000..406fc4eb17 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_preprocessor_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Phi3PreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = Phi3Tokenizer( + # Generated using create_phi3_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "phi3_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 12, + } + self.input_data = ( + # Encoded to [3, 5, 6, 4, 3, 9, 7, 11, 3, 15] + ["the fox <|endoftext|>"], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=Phi3Preprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 3, 15, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ), + ) + + def test_errors_for_2d_list_input(self): + preprocessor = Phi3Preprocessor(**self.init_kwargs) + ambiguous_input = [["one", "two"], ["three", "four"]] + with self.assertRaises(ValueError): + preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi3Preprocessor.presets: + self.run_preset_test( + cls=Phi3Preprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/phi3/phi3_presets.py b/keras_nlp/src/models/phi3/phi3_presets.py new file mode 100644 index 0000000000..48ea0c1994 --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_presets.py @@ -0,0 +1,50 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Phi-3 model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "phi3_mini_4k_instruct_en": {
+        "metadata": {
+            "description": (
+                "3.8 billion parameters, 32 layers, 4k context length, Phi-3 "
+                "model. The model was trained using the Phi-3 datasets. This "
+                "dataset includes both synthetic data and filtered publicly "
+                "available website data, with an emphasis on high-quality and "
+                "reasoning-dense properties."
+            ),
+            "params": 3821079552,
+            "official_name": "Phi-3",
+            "path": "phi3",
+            "model_card": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+        },
+        "kaggle_handle": "kaggle://keras/phi3/keras/phi3_mini_4k_instruct_en",
+    },
+    "phi3_mini_128k_instruct_en": {
+        "metadata": {
+            "description": (
+                "3.8 billion parameters, 32 layers, 128k context length, "
+                "Phi-3 model. The model was trained using the Phi-3 "
+                "datasets. This dataset includes both synthetic data and "
+                "filtered publicly available website data, with an emphasis "
+                "on high-quality and reasoning-dense properties."
+            ),
+            "params": 3821079552,
+            "official_name": "Phi-3",
+            "path": "phi3",
+            "model_card": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct",
+        },
+        "kaggle_handle": "kaggle://keras/phi3/keras/phi3_mini_128k_instruct_en",
+    },
+}
diff --git a/keras_nlp/src/models/phi3/phi3_rotary_embedding.py b/keras_nlp/src/models/phi3/phi3_rotary_embedding.py
new file mode 100644
index 0000000000..ff628b3dd7
--- /dev/null
+++ b/keras_nlp/src/models/phi3/phi3_rotary_embedding.py
@@ -0,0 +1,136 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+from keras_nlp.src.backend import ops
+from keras_nlp.src.layers.modeling.rotary_embedding import RotaryEmbedding
+
+
+class Phi3SuScaledRotaryEmbedding(RotaryEmbedding):
+    """Su-scaled rotary positional encoding layer.
+
+    Args:
+        inverese_freq_short_factor: List[float]. A list of factors used to
+            adjust the rope frequencies when the `rope_scaling_type` is
+            `"su"`. The list must be of length
+            `hidden_dim // num_query_heads // 2`. It is used when
+            `sequence_length` is smaller than `original_max_sequence_length`.
+        inverese_freq_long_factor: List[float]. A list of factors used to
+            adjust the rope frequencies when the `rope_scaling_type` is
+            `"su"`. The list must be of length
+            `hidden_dim // num_query_heads // 2`. It is used when
+            `sequence_length` is larger than `original_max_sequence_length`.
+        max_sequence_length: int. The maximum sequence length that this
+            model might ever be used with.
+        pretraining_sequence_length: int.
The maximum sequence length that + this model was pretrained with. + max_wavelength: int. The maximum angular wavelength of the sine/cosine + curves. + + Call arguments: + inputs: The tensor inputs to apply the embedding to. This can have + any shape, but must contain both a sequence and feature axis. The + rotary embedding will be applied to `inputs` and returned. + start_index: An integer or integer tensor. The starting position to + compute the rotary embedding from. This is useful during cached + decoding, where each position is predicted separately in a loop. + + References: + - [Phi-3-mini-128k-instruct original implementation](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/0693e0b867d29e7318280ddd3ff9d5e66698f488/modeling_phi3.py#L142) + """ + + def __init__( + self, + inverese_freq_short_factor, + inverese_freq_long_factor, + max_sequence_length=4096, + pretraining_sequence_length=4096, + max_wavelength=10000, + **kwargs + ): + super().__init__(max_wavelength=max_wavelength, **kwargs) + self.max_sequence_length = max_sequence_length + self.pretraining_sequence_length = pretraining_sequence_length + + scaling_factor = ( + self.max_sequence_length / self.pretraining_sequence_length + ) + if scaling_factor <= 1.0: + self.embedding_scaling_factor = 1.0 + else: + self.embedding_scaling_factor = math.sqrt( + 1 + + math.log(scaling_factor) + / math.log(self.pretraining_sequence_length) + ) + + self.inverese_freq_short_factor = inverese_freq_short_factor + self.inverese_freq_long_factor = inverese_freq_long_factor + + def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None): + feature_axis = len(inputs.shape) - 1 + sequence_axis = 1 + + rotary_dim = ops.shape(inputs)[feature_axis] + inverse_freq = self._get_inverse_freq(rotary_dim) + + # Multiply inverse_freq by a factor. + if ops.shape(inputs)[sequence_axis] > self.pretraining_sequence_length: + inverse_freq = ops.divide( + inverse_freq, + ops.convert_to_tensor(self.inverese_freq_long_factor), + ) + else: + inverse_freq = ops.divide( + inverse_freq, + ops.convert_to_tensor(self.inverese_freq_short_factor), + ) + + if positions is None: + positions = self._compute_positions(inputs, start_index) + else: + positions = ops.cast(positions, "float32") + + freq = ops.einsum("i,j->ij", positions, inverse_freq) + embedding = ops.stack((freq, freq), axis=-2) + embedding = ops.reshape( + embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2) + ) + + # Reshape the embedding to be broadcastable with input shape. 
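+        # For example, for inputs of shape `(batch, sequence, heads, head_dim)`,
+        # `embedding` has shape `(sequence, head_dim)` at this point and is
+        # expanded to `(1, sequence, 1, head_dim)` below, so that it
+        # broadcasts over the batch and head axes.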
+ if feature_axis < sequence_axis: + embedding = ops.transpose(embedding) + for axis in range(len(inputs.shape)): + if axis != sequence_axis and axis != feature_axis: + embedding = ops.expand_dims(embedding, axis) + + cos_emb = ops.cast( + ops.cos(embedding) * self.embedding_scaling_factor, + self.compute_dtype, + ) + sin_emb = ops.cast( + ops.sin(embedding) * self.embedding_scaling_factor, + self.compute_dtype, + ) + return cos_emb, sin_emb + + def get_config(self): + config = super().get_config() + config.update( + { + "max_sequence_length": self.max_sequence_length, + "pretraining_sequence_length": self.pretraining_sequence_length, + "inverese_freq_short_factor": self.inverese_freq_short_factor, + "inverese_freq_long_factor": self.inverese_freq_long_factor, + } + ) + return config diff --git a/keras_nlp/src/models/phi3/phi3_tokenizer.py b/keras_nlp/src/models/phi3/phi3_tokenizer.py new file mode 100644 index 0000000000..d45201ff6b --- /dev/null +++ b/keras_nlp/src/models/phi3/phi3_tokenizer.py @@ -0,0 +1,94 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.phi3.phi3_presets import backbone_presets +from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) +from keras_nlp.src.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.Phi3Tokenizer") +class Phi3Tokenizer(SentencePieceTokenizer): + """Phi3 tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Phi3 models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Phi3 preset. + + This tokenizer does not provide truncation or padding of inputs. It can be + combined with a `keras_nlp.models.Phi3Preprocessor` layer for input + packing. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + ```python + # Unbatched input. + tokenizer = keras_nlp.models.Phi3Tokenizer.from_preset( + "phi3_mini_4k_instruct_en", + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. 
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+    ```
+    """
+
+    def __init__(self, proto, **kwargs):
+        self.start_token = "<s>"
+        self.end_token = "<|endoftext|>"
+        super().__init__(proto=proto, **kwargs)
+
+    def set_proto(self, proto):
+        super().set_proto(proto)
+        if proto is not None:
+            for token in [self.start_token, self.end_token]:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            # TODO: `pad_token` should be `<|endoftext|>`, but it is set to
+            # `<unk>` (id 0) for now because of the way the sampler works.
+            # The sampler would treat `pad_token` as `end_token` and stop
+            # generation immediately.
+            self.pad_token_id = 0
+        else:
+            self.start_token_id = None
+            self.end_token_id = None
+            self.pad_token_id = None
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/src/models/phi3/phi3_tokenizer_test.py b/keras_nlp/src/models/phi3/phi3_tokenizer_test.py
new file mode 100644
index 0000000000..2c823acfb7
--- /dev/null
+++ b/keras_nlp/src/models/phi3/phi3_tokenizer_test.py
@@ -0,0 +1,81 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os + +import pytest + +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_nlp.src.tests.test_case import TestCase + + +class Phi3TokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using create_phi3_test_proto.py + "proto": os.path.join( + self.get_test_data_dir(), "phi3_test_vocab.spm" + ) + } + # `<|endoftext|>` id = vocab_size = 15 + self.input_data = [ + "the fox <|endoftext|>", + "the earth <|endoftext|>", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=Phi3Tokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[ + [3, 5, 6, 4, 3, 9, 7, 11, 3, 15], + [3, 5, 6, 4, 3, 4, 8, 14, 5, 6, 3, 15], + ], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + Phi3Tokenizer( + # Generated using create_no_special_token_proto.py + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) + # Llama proto doesn't have `<|endoftext|>` + with self.assertRaises(ValueError): + Phi3Tokenizer( + # Generated using create_no_special_token_proto.py + proto=os.path.join( + self.get_test_data_dir(), "llama_test_vocab.spm" + ) + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=Phi3Tokenizer, + preset="phi3_mini_4k_instruct_en", + input_data=["The quick brown fox."], + expected_output=[[450, 4996, 17354, 1701, 29916, 29889]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Phi3Tokenizer.presets: + self.run_preset_test( + cls=Phi3Tokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/tests/test_data/phi3_test_vocab.spm b/keras_nlp/src/tests/test_data/phi3_test_vocab.spm new file mode 100644 index 0000000000..9d2bb0fd9e Binary files /dev/null and b/keras_nlp/src/tests/test_data/phi3_test_vocab.spm differ diff --git a/tools/checkpoint_conversion/convert_phi3_checkpoints.py b/tools/checkpoint_conversion/convert_phi3_checkpoints.py new file mode 100644 index 0000000000..0ec3fd4d2c --- /dev/null +++ b/tools/checkpoint_conversion/convert_phi3_checkpoints.py @@ -0,0 +1,445 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
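+
+# Example invocation (illustrative; all flags are defined in `main()` below):
+#
+#   python tools/checkpoint_conversion/convert_phi3_checkpoints.py \
+#       --preset phi3_mini_4k_instruct_en --validate_dtype float32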
+
+import argparse
+import gc
+import json
+import os
+import re
+import sys
+
+os.environ["KERAS_BACKEND"] = "torch"
+
+import huggingface_hub  # noqa: E402
+import keras  # noqa: E402
+import torch  # noqa: E402
+import transformers  # noqa: E402
+
+from keras_nlp import upload_preset  # noqa: E402
+from keras_nlp.src.models import Phi3Backbone  # noqa: E402
+from keras_nlp.src.models import Phi3Preprocessor  # noqa: E402
+from keras_nlp.src.models import Phi3Tokenizer  # noqa: E402
+
+PRESET_MAP = {
+    "phi3_mini_4k_instruct_en": "microsoft/Phi-3-mini-4k-instruct",
+    "phi3_mini_128k_instruct_en": "microsoft/Phi-3-mini-128k-instruct",
+}
+
+
+def download_hf_model(hf_model_name, extract_dir):
+    hf_model_dir = huggingface_hub.snapshot_download(
+        repo_id=hf_model_name,
+        allow_patterns=["*.json", "*.safetensors", "*.py", "*.model"],
+        ignore_patterns=["*/*"],
+        local_dir=extract_dir,
+    )
+
+    return hf_model_dir
+
+
+def convert_tokenizer(hf_model_dir):
+    # We can't import `sentencepiece_model_pb2` from `sentencepiece` directly,
+    # because doing so raises this protobuf error:
+    # `TypeError: Couldn't build proto file into descriptor pool: duplicate
+    # file name sentencepiece_model.proto`
+    # The transformers library builds `sentencepiece_model.proto` once; we
+    # can't import it again, because protobuf would try to rebuild the proto
+    # in the descriptor pool, which protobuf forbids.
+    sp_pb2_module = sys.modules.get(
+        "transformers.utils.sentencepiece_model_pb2_new", None
+    )
+    if sp_pb2_module is None:
+        sp_pb2_module = sys.modules.get(
+            "transformers.utils.sentencepiece_model_pb2", None
+        )
+    if sp_pb2_module is None:
+        from sentencepiece import sentencepiece_model_pb2 as sp_pb2_module
+    model_path = os.path.join(hf_model_dir, "tokenizer.model")
+    added_tokens_path = os.path.join(hf_model_dir, "added_tokens.json")
+    with open(model_path, "rb") as sp_model_file:
+        model_proto = sp_pb2_module.ModelProto()
+        model_proto.ParseFromString(sp_model_file.read())
+    with open(added_tokens_path, "rb") as added_tokens_file:
+        added_tokens = json.load(added_tokens_file)
+    # Add the new tokens to the model as user-defined pieces.
+    for token in added_tokens.keys():
+        new_token = sp_pb2_module.ModelProto().SentencePiece()
+        new_token.piece = token
+        new_token.score = 0.0
+        new_token.type = 4  # User-defined symbols.
+        model_proto.pieces.append(new_token)
+    tokenizer = Phi3Tokenizer(model_proto.SerializeToString())
+    for key, value in added_tokens.items():
+        assert key == tokenizer.id_to_token(
+            value
+        ), f"{key} token has a different id in the tokenizer"
+
+    return tokenizer
+
+
+def convert_model(hf_model, device, dtype):
+    hf_config = hf_model.config.to_dict()
+
+    kwargs = {}
+    kwargs["vocabulary_size"] = hf_config["vocab_size"]
+    kwargs["num_layers"] = hf_config["num_hidden_layers"]
+    kwargs["num_query_heads"] = hf_config["num_attention_heads"]
+    kwargs["num_key_value_heads"] = hf_config["num_key_value_heads"]
+    kwargs["hidden_dim"] = hf_config["hidden_size"]
+    kwargs["intermediate_dim"] = hf_config["intermediate_size"]
+    kwargs["dropout"] = hf_config["attention_dropout"]
+    kwargs["layer_norm_epsilon"] = hf_config["rms_norm_eps"]
+    kwargs["max_sequence_length"] = hf_config["max_position_embeddings"]
+    kwargs["original_max_sequence_length"] = hf_config[
+        "original_max_position_embeddings"
+    ]
+    kwargs["rope_max_wavelength"] = hf_config["rope_theta"]
+    if hf_config["rope_scaling"] is not None:
+        kwargs["rope_scaling_type"] = hf_config["rope_scaling"]["type"]
+        kwargs["rope_scaling_short_factor"] = hf_config["rope_scaling"][
+            "short_factor"
+        ]
+        kwargs["rope_scaling_long_factor"] = hf_config["rope_scaling"][
+            "long_factor"
+        ]
+    kwargs["dtype"] = dtype
+
+    with keras.device(device):
+        keras_model = Phi3Backbone(**kwargs)
+
+    return keras_model
+
+
+def convert_weights(keras_model, hf_model):
+    hidden_dim = keras_model.hidden_dim
+    intermediate_dim = keras_model.intermediate_dim
+    num_query_heads = keras_model.num_query_heads
+    num_key_value_heads = keras_model.num_key_value_heads
+    head_dim = hidden_dim // num_query_heads
+
+    # Get the huggingface model weights.
+    hf_wts = hf_model.state_dict()
+
+    # Embedding layer.
+    keras_model.token_embedding.embeddings.assign(
+        hf_wts["model.embed_tokens.weight"]
+    )
+    keras_model.token_embedding.reverse_embeddings.assign(
+        hf_wts["lm_head.weight"].t()
+    )
+    # LayerNorm.
+    keras_model.layer_norm.scale.assign(hf_wts["model.norm.weight"])
+
+    # Decoder layers.
+    for i, decoder_layer in enumerate(keras_model.transformer_layers):
+        # LayerNorm.
+        decoder_layer.pre_attention_layernorm.scale.assign(
+            hf_wts[f"model.layers.{i}.input_layernorm.weight"]
+        )
+        decoder_layer.post_attention_layernorm.scale.assign(
+            hf_wts[f"model.layers.{i}.post_attention_layernorm.weight"]
+        )
+
+        # Attention layer.
+        attention_layer = decoder_layer.attention
+        fused_qkv_kernel = hf_wts[
+            f"model.layers.{i}.self_attn.qkv_proj.weight"
+        ].t()
+
+        query_kernel = fused_qkv_kernel[:, :hidden_dim]
+        query_kernel = query_kernel.reshape(
+            hidden_dim, num_query_heads, head_dim
+        )
+        key_kernel = fused_qkv_kernel[
+            :, hidden_dim : hidden_dim + num_key_value_heads * head_dim
+        ]
+        key_kernel = key_kernel.reshape(
+            hidden_dim, num_key_value_heads, head_dim
+        )
+        value_kernel = fused_qkv_kernel[
+            :, hidden_dim + num_key_value_heads * head_dim :
+        ]
+        value_kernel = value_kernel.reshape(
+            hidden_dim, num_key_value_heads, head_dim
+        )
+
+        attention_layer.query_dense._kernel.assign(query_kernel)
+        attention_layer.key_dense._kernel.assign(key_kernel)
+        attention_layer.value_dense._kernel.assign(value_kernel)
+
+        attention_layer.output_dense.kernel.assign(
+            hf_wts[f"model.layers.{i}.self_attn.o_proj.weight"]
+            .t()
+            .reshape(num_query_heads, head_dim, hidden_dim)
+        )
+
+        # Feed forward layer.
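+        # `gate_up_proj` fuses the gate and intermediate (up) projections:
+        # after transposing, the first `intermediate_dim` columns hold the
+        # gate kernel and the remaining columns hold the intermediate kernel.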
+ fused_intermediate_gate_ff_kernel = hf_wts[ + f"model.layers.{i}.mlp.gate_up_proj.weight" + ].t() + decoder_layer.feedforward_gate_dense._kernel.assign( + fused_intermediate_gate_ff_kernel[:, :intermediate_dim] + ) + decoder_layer.feedforward_intermediate_dense._kernel.assign( + fused_intermediate_gate_ff_kernel[:, intermediate_dim:] + ) + decoder_layer.feedforward_output_dense._kernel.assign( + hf_wts[f"model.layers.{i}.mlp.down_proj.weight"].t() + ) + + +def validate_output( + hf_model, + keras_model, + hf_device, + keras_device, + hf_tokenizer, + keras_preprocessor, +): + # Hf + tokens = hf_tokenizer( + ["<|user|>How to win?<|end|><|assistant|>"], + max_length=9, + padding="max_length", + return_tensors="pt", + ) + + hf_model_input = { + "input_ids": tokens["input_ids"].to(hf_device), + "attention_mask": tokens["attention_mask"].to(hf_device), + "use_cache": False, + "output_attentions": False, + "output_hidden_states": False, + "return_dict": False, + } + + hf_model_outputs = hf_model(**hf_model_input)[0] + + # KerasNLP + keras_model_input = keras_preprocessor( + ["<|user|>How to win?<|end|><|assistant|>"] + ) + keras_model_input = { + k: v.to(keras_device) for k, v in keras_model_input.items() + } + keras_model_outputs = keras_model(keras_model_input) + + # Comparing the outputs. + print("🔶 KerasNLP output:", keras_model_outputs[0, 0, :10]) + print("🔶 HF output:", hf_model_outputs[0, 0, :10]) + print( + "🔶 Difference:", + torch.mean( + torch.abs( + keras_model_outputs.detach().cpu() + - hf_model_outputs.detach().cpu() + ) + ), + ) + + +def get_torch_dtype(str_dtype): + if str_dtype == "float32": + return torch.float32 + elif str_dtype == "float16": + return torch.float16 + elif str_dtype == "bfloat16": + return torch.bfloat16 + + +def convert_and_validate( + hf_model_dir, + hf_device, + keras_device, + validate_dtype, + hf_tokenizer, + keras_preprocessor, +): + print(f"✅ Numerics Validation in {validate_dtype}.") + # Load the causal model to convert lm_head weights. + hf_causal_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_model_dir, + device_map=hf_device, + torch_dtype=get_torch_dtype(validate_dtype), + trust_remote_code=True, + ) + hf_model = hf_causal_model.model + print("✅ Huggingface model loaded.") + + keras_model = convert_model(hf_causal_model, keras_device, validate_dtype) + print("✅ Keras model loaded.") + + convert_weights(keras_model, hf_causal_model) + print("✅ Weights converted") + + validate_output( + hf_model, + keras_model, + hf_device, + keras_device, + hf_tokenizer, + keras_preprocessor, + ) + print("✅ Numerics validated") + + # Clean memory. + del keras_model + del hf_causal_model + del hf_model + gc.collect() + if not (hf_device == "cpu" and keras_device == "cpu"): + torch.cuda.empty_cache() + + +def convert_and_save( + preset, + hf_model_dir, + hf_device, + keras_device, + save_dtype, + keras_preprocessor, +): + print(f"✅ Saving model in {save_dtype}.") + # Load the causal model to convert lm_head weights. 
+    hf_causal_model = transformers.AutoModelForCausalLM.from_pretrained(
+        hf_model_dir,
+        device_map=hf_device,
+        torch_dtype=get_torch_dtype(save_dtype),
+        trust_remote_code=True,
+    )
+    print("✅ Huggingface model loaded.")
+
+    keras_model = convert_model(hf_causal_model, keras_device, save_dtype)
+    print("✅ Keras model loaded.")
+
+    convert_weights(keras_model, hf_causal_model)
+    print("✅ Weights converted")
+
+    keras_model.save_to_preset(preset)
+    keras_preprocessor.tokenizer.save_to_preset(preset)
+    print("✅ Preset saved")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--preset",
+        default="phi3_mini_4k_instruct_en",
+        choices=PRESET_MAP.keys(),
+        required=True,
+        help=f'Preset must be one of {", ".join(PRESET_MAP.keys())}',
+    )
+
+    def device_regex(arg_value, pattern=re.compile(r"^cpu$|^cuda:[0-9]+$")):
+        if not pattern.match(arg_value):
+            raise argparse.ArgumentTypeError(
+                "The device must be one of: "
+                "'cpu', 'cuda:0', 'cuda:1', ... 'cuda:n'"
+            )
+        return arg_value
+
+    parser.add_argument(
+        "--hf_device",
+        default="cpu",
+        type=device_regex,
+        help=(
+            "The device where the huggingface model will be loaded. Can be "
+            "one of: 'cpu', 'cuda:0', 'cuda:1', ... 'cuda:n'."
+        ),
+    )
+    parser.add_argument(
+        "--keras_device",
+        default="cpu",
+        type=device_regex,
+        help=(
+            "The device where the keras model will be loaded. Can be one "
+            "of: 'cpu', 'cuda:0', 'cuda:1', ... 'cuda:n'."
+        ),
+    )
+    parser.add_argument(
+        "--validate_dtype",
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help=(
+            "The dtype of the two models while validating numerics. "
+            "Can be 'float32', 'float16', or 'bfloat16'."
+        ),
+    )
+    parser.add_argument(
+        "--save_dtype",
+        choices=["float32", "float16", "bfloat16"],
+        default="bfloat16",
+        help=(
+            "The dtype that the keras model will be saved with. "
+            "Can be 'float32', 'float16', or 'bfloat16'."
+        ),
+    )
+    parser.add_argument(
+        "--upload_link",
+        type=str,
+        help=(
+            "The link to upload the model. Can be in these formats: "
+            "`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`, "
+            "`hf://[<HF_USERNAME>/]<MODEL>`"
+        ),
+    )
+
+    args = parser.parse_args()
+    preset = args.preset
+    hf_device = args.hf_device
+    keras_device = args.keras_device
+    validate_dtype = args.validate_dtype
+    save_dtype = args.save_dtype
+    upload_link = args.upload_link
+
+    print(f"✅ Converting {preset}.")
+
+    hf_model_name = PRESET_MAP[preset]
+    hf_model_dir = download_hf_model(hf_model_name, f"./{preset}_hf_model")
+    print("✅ Huggingface model downloaded from the hub.")
+
+    keras_tokenizer = convert_tokenizer(hf_model_dir)
+    keras_preprocessor = Phi3Preprocessor(
+        tokenizer=keras_tokenizer, sequence_length=9
+    )
+    # Phi3 uses the llama tokenizer.
+    hf_tokenizer = transformers.LlamaTokenizer.from_pretrained(
+        hf_model_dir, padding_side="right"
+    )
+
+    convert_and_validate(
+        hf_model_dir=hf_model_dir,
+        hf_device=hf_device,
+        keras_device=keras_device,
+        validate_dtype=validate_dtype,
+        hf_tokenizer=hf_tokenizer,
+        keras_preprocessor=keras_preprocessor,
+    )
+
+    convert_and_save(
+        preset=preset,
+        hf_device=hf_device,
+        keras_device=keras_device,
+        save_dtype=save_dtype,
+        hf_model_dir=hf_model_dir,
+        keras_preprocessor=keras_preprocessor,
+    )
+
+    if upload_link is not None:
+        upload_preset(upload_link, preset)
+        print("✅ Preset uploaded")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/sentencepiece_testing/create_phi3_test_proto.py b/tools/sentencepiece_testing/create_phi3_test_proto.py
new file mode 100644
index 0000000000..a9e02ff92e
--- /dev/null
+++ b/tools/sentencepiece_testing/create_phi3_test_proto.py
@@ -0,0 +1,75 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+
+import sentencepiece.sentencepiece_model_pb2 as sp_pb2
+
+from tools.sentencepiece_testing.utils import train_sentencepiece
+
+ADDED_TOKENS = [
+    "<|endoftext|>",
+    "<|assistant|>",
+    "<|system|>",
+    "<|end|>",
+    "<|user|>",
+]
+
+
+def add_added_tokens(filename):
+    with open(
+        pathlib.Path(__file__).parent.parent.parent
+        / "keras_nlp"
+        / "src"
+        / "tests"
+        / "test_data"
+        / filename,
+        mode="rb",
+    ) as sp_model_file:
+        model_proto = sp_pb2.ModelProto()
+        model_proto.ParseFromString(sp_model_file.read())
+    for token in ADDED_TOKENS:
+        new_token = sp_pb2.ModelProto().SentencePiece()
+        new_token.piece = token
+        new_token.score = 0.0
+        new_token.type = 4  # User-defined symbols.
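+        # `type = 4` corresponds to
+        # `ModelProto.SentencePiece.Type.USER_DEFINED` in the SentencePiece
+        # proto schema.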
+ model_proto.pieces.append(new_token) + with open( + pathlib.Path(__file__).parent.parent.parent + / "keras_nlp" + / "src" + / "tests" + / "test_data" + / filename, + mode="wb", + ) as f: + f.write(model_proto.SerializeToString()) + + +def main(): + train_sentencepiece( + ["the fox on the table", "the fox on the earth"], + "phi3_test_vocab.spm", + vocab_size=15, + model_type="bpe", # BPE + pad_id=-1, + unk_id=0, + bos_id=1, + eos_id=2, + ) + add_added_tokens("phi3_test_vocab.spm") + + +if __name__ == "__main__": + main() diff --git a/tools/sentencepiece_testing/utils.py b/tools/sentencepiece_testing/utils.py index 9deebd9737..caaeae7249 100644 --- a/tools/sentencepiece_testing/utils.py +++ b/tools/sentencepiece_testing/utils.py @@ -25,6 +25,7 @@ def train_sentencepiece(data, filename, *args, **kwargs): with open( pathlib.Path(__file__).parent.parent.parent / "keras_nlp" + / "src" / "tests" / "test_data" / filename,
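
Reviewer note: a minimal end-to-end sketch of the new API surface added by this diff, assuming the Kaggle presets listed in `phi3_presets.py` have been published and are downloadable (not part of the diff itself):

```python
import keras_nlp

# Load the tokenizer and preprocessor from the registered preset.
tokenizer = keras_nlp.models.Phi3Tokenizer.from_preset(
    "phi3_mini_4k_instruct_en"
)
preprocessor = keras_nlp.models.Phi3Preprocessor.from_preset(
    "phi3_mini_4k_instruct_en",
    sequence_length=16,
)

# Raw strings in, packed `token_ids` / `padding_mask` out.
batch = preprocessor(["<|user|>How to win?<|end|><|assistant|>"])
print(batch["token_ids"].shape)  # (1, 16)

# The backbone consumes the packed dictionary directly.
backbone = keras_nlp.models.Phi3Backbone.from_preset(
    "phi3_mini_4k_instruct_en"
)
hidden_states = backbone(batch)  # (1, 16, hidden_dim)
```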