-
Notifications
You must be signed in to change notification settings - Fork 305
Add a preprocessor for the Mistral backbone #1385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tirthasheshpatel
merged 4 commits into
keras-team:master
from
tirthasheshpatel:mistral-preprocessor
Jan 5, 2024
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| # Copyright 2023 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from keras_nlp.api_export import keras_nlp_export | ||
| from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker | ||
| from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer | ||
| from keras_nlp.models.preprocessor import Preprocessor | ||
| from keras_nlp.utils.keras_utils import ( | ||
| convert_inputs_to_list_of_tensor_segments, | ||
| ) | ||
| from keras_nlp.utils.keras_utils import pack_x_y_sample_weight | ||
| from keras_nlp.utils.python_utils import classproperty | ||
|
|
||
|
|
||
| @keras_nlp_export("keras_nlp.models.MistralPreprocessor") | ||
| class MistralPreprocessor(Preprocessor): | ||
| """An Mistral preprocessing layer which tokenizes and packs inputs. | ||
tirthasheshpatel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| This preprocessing layer will do three things: | ||
|
|
||
| 1. Tokenize any number of input segments using the `tokenizer`. | ||
| 2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker`. | ||
| with the appropriate tokens. | ||
| 3. Construct a dictionary with keys `"token_ids"`, and `"padding_mask"` | ||
| that can be passed directly to `keras_nlp.models.MistralBackbone`. | ||
|
|
||
| This layer can be used directly with `tf.data.Dataset.map` to preprocess | ||
| string data in the `(x, y, sample_weight)` format used by | ||
| `keras.Model.fit`. | ||
|
|
||
| Args: | ||
| tokenizer: A `keras_nlp.models.MistralTokenizer` instance. | ||
| sequence_length: The length of the packed inputs. | ||
| add_start_token: If `True`, the preprocessor will prepend the tokenizer | ||
| start token to each input sequence. Default is `True`. | ||
| add_end_token: If `True`, the preprocessor will append the tokenizer | ||
| end token to each input sequence. Default is `False`. | ||
|
|
||
| Call arguments: | ||
| x: A tensor of single string sequences, or a tuple of multiple | ||
| tensor sequences to be packed together. Inputs may be batched or | ||
| unbatched. For single sequences, raw python inputs will be converted | ||
| to tensors. For multiple sequences, pass tensors directly. | ||
| y: Any label data. Will be passed through unaltered. | ||
| sample_weight: Any label weight data. Will be passed through unaltered. | ||
| sequence_length: Pass to override the configured `sequence_length` of | ||
| the layer. | ||
|
|
||
| Examples: | ||
|
|
||
| Directly calling the from_preset(). | ||
| ```python | ||
| preprocessor = keras_nlp.models.MistralPreprocessor.from_preset( | ||
| "mistral_base_en" | ||
| ) | ||
|
|
||
| # Tokenize and pack a single sentence. | ||
| preprocessor("The quick brown fox jumped.") | ||
|
|
||
| # Tokenize and a batch of single sentences. | ||
| preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) | ||
|
|
||
| # Preprocess a batch of sentence pairs. | ||
| # When handling multiple sequences, always convert to tensors first! | ||
| first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) | ||
| second = tf.constant(["The fox tripped.", "Oh look, a whale."]) | ||
| preprocessor((first, second)) | ||
| ``` | ||
|
|
||
| Mapping with `tf.data.Dataset`. | ||
| ```python | ||
| preprocessor = keras_nlp.models.MistralPreprocessor.from_preset( | ||
| "mistral_base_en" | ||
| ) | ||
| first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) | ||
| second = tf.constant(["The fox tripped.", "Oh look, a whale."]) | ||
| label = tf.constant([1, 1]) | ||
|
|
||
| # Map labeled single sentences. | ||
| ds = tf.data.Dataset.from_tensor_slices((first, label)) | ||
| ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
|
|
||
| # Map unlabeled single sentences. | ||
| ds = tf.data.Dataset.from_tensor_slices(first) | ||
| ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
|
|
||
| # Map labeled sentence pairs. | ||
| ds = tf.data.Dataset.from_tensor_slices(((first, second), label)) | ||
| ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
|
|
||
| # Map unlabeled sentence pairs. | ||
| ds = tf.data.Dataset.from_tensor_slices((first, second)) | ||
|
|
||
| # Watch out for tf.data's default unpacking of tuples here! | ||
| # Best to invoke the `preprocessor` directly in this case. | ||
| ds = ds.map( | ||
| lambda first, second: preprocessor(x=(first, second)), | ||
| num_parallel_calls=tf.data.AUTOTUNE, | ||
| ) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| tokenizer, | ||
| sequence_length=1024, | ||
| add_start_token=True, | ||
| add_end_token=False, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.tokenizer = tokenizer | ||
| self.add_start_token = add_start_token | ||
| self.add_end_token = add_end_token | ||
| self.sequence_length = sequence_length | ||
| self.packer = StartEndPacker( | ||
| start_value=self.tokenizer.start_token_id, | ||
| end_value=self.tokenizer.end_token_id, | ||
| sequence_length=sequence_length, | ||
| return_padding_mask=True, | ||
| ) | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "sequence_length": self.sequence_length, | ||
| "add_start_token": self.add_start_token, | ||
| "add_end_token": self.add_end_token, | ||
| } | ||
| ) | ||
| return config | ||
|
|
||
| def call( | ||
| self, | ||
| x, | ||
| y=None, | ||
| sample_weight=None, | ||
| sequence_length=None, | ||
| ): | ||
| x = convert_inputs_to_list_of_tensor_segments(x) | ||
| if len(x) != 1: | ||
| raise ValueError( | ||
| "Mistral requires each input feature to contain only " | ||
| f"one segment, but received {len(x)}. If you are using Mistral" | ||
| " for a multi-segment classification task, please refer to " | ||
| "classification models like BERT or RoBERTa." | ||
| ) | ||
| sequence_length = sequence_length or self.sequence_length | ||
| token_ids, padding_mask = self.packer( | ||
| self.tokenizer(x[0]), | ||
| sequence_length=sequence_length, | ||
| add_start_value=self.add_start_token, | ||
| add_end_value=self.add_end_token, | ||
| ) | ||
| x = { | ||
| "token_ids": token_ids, | ||
| "padding_mask": padding_mask, | ||
| } | ||
| return pack_x_y_sample_weight(x, y, sample_weight) | ||
|
|
||
| @classproperty | ||
| def tokenizer_cls(cls): | ||
| return MistralTokenizer | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| # Copyright 2023 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
tirthasheshpatel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor | ||
| from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer | ||
| from keras_nlp.tests.test_case import TestCase | ||
|
|
||
|
|
||
| class MistralPreprocessorTest(TestCase): | ||
| def setUp(self): | ||
| self.tokenizer = MistralTokenizer( | ||
| # Generated using create_mistral_test_proto.py | ||
| proto=os.path.join( | ||
| self.get_test_data_dir(), "mistral_test_vocab.spm" | ||
| ) | ||
| ) | ||
| self.init_kwargs = { | ||
| "tokenizer": self.tokenizer, | ||
| "sequence_length": 8, | ||
| } | ||
| self.input_data = ( | ||
| ["the quick brown fox"], | ||
| [1], # Pass through labels. | ||
| [1.0], # Pass through sample_weights. | ||
| ) | ||
|
|
||
| def test_preprocessor_basics(self): | ||
| self.run_preprocessing_layer_test( | ||
| cls=MistralPreprocessor, | ||
| init_kwargs=self.init_kwargs, | ||
| input_data=self.input_data, | ||
| expected_output=( | ||
| { | ||
| "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], | ||
| "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], | ||
| }, | ||
| [1], # Pass through labels. | ||
| [1.0], # Pass through sample_weights. | ||
| ), | ||
| ) | ||
|
|
||
| def test_errors_for_2d_list_input(self): | ||
| preprocessor = MistralPreprocessor(**self.init_kwargs) | ||
| ambiguous_input = [["one", "two"], ["three", "four"]] | ||
| with self.assertRaises(ValueError): | ||
| preprocessor(ambiguous_input) | ||
|
|
||
| @pytest.mark.extra_large | ||
| def test_all_presets(self): | ||
tirthasheshpatel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for preset in MistralPreprocessor.presets: | ||
| self.run_preset_test( | ||
| cls=MistralPreprocessor, | ||
| preset=preset, | ||
| input_data=self.input_data, | ||
| ) | ||
tirthasheshpatel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.