diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index a479c21004..2a425245e7 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -21,6 +21,10 @@ from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.models.bart.bart_backbone import BartBackbone
+from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
+from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
+    BartSeq2SeqLMPreprocessor,
+)
 from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
 from keras_nlp.models.bert.bert_backbone import BertBackbone
 from keras_nlp.models.bert.bert_classifier import BertClassifier
diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py
new file mode 100644
index 0000000000..f1da3e4fba
--- /dev/null
+++ b/keras_nlp/models/bart/bart_preprocessor.py
@@ -0,0 +1,288 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BART preprocessor layer."""
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
+from keras_nlp.models.bart.bart_presets import backbone_presets
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.BartPreprocessor")
+class BartPreprocessor(Preprocessor):
+    """A BART preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    1. Tokenize both encoder inputs and decoder inputs using the `tokenizer`.
+       Both inputs can contain only one segment.
+    2. Add the appropriate special tokens - `"<s>"`, `"</s>"` and `"<pad>"`.
+    3. Construct a dictionary with keys `"encoder_token_ids"`,
+       `"encoder_padding_mask"`, `"decoder_token_ids"`,
+       `"decoder_padding_mask"` that can be passed directly to a BART model.
+
+    Args:
+        tokenizer: A `keras_nlp.models.BartTokenizer` instance.
+        encoder_sequence_length: The length of the packed encoder inputs.
+        decoder_sequence_length: The length of the packed decoder inputs.
+        truncate: string. The algorithm to truncate a list of batched segments
+            to fit within `sequence_length`. The value can be either
+            `round_robin` or `waterfall`:
+            - `"round_robin"`: Available space is assigned one token at a
+                time in a round-robin fashion to the inputs that still need
+                some, until the limit is reached.
+            - `"waterfall"`: The allocation of the budget is done using a
+                "waterfall" algorithm that allocates quota in a
+                left-to-right manner and fills up the buckets until we run
+                out of budget. It supports an arbitrary number of segments.
+
+    Call arguments:
+        x: A dictionary with `encoder_text` and `decoder_text` as its keys.
+            Each value in the dictionary should be a tensor of single string
+            sequences. Inputs may be batched or unbatched. Raw Python inputs
+            will be converted to tensors.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through
+            unaltered.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")
+
+    # Preprocess unbatched inputs.
+    inputs = {
+        "encoder_text": "The fox was sleeping.",
+        "decoder_text": "The fox was awake."
+    }
+    preprocessor(inputs)
+
+    # Preprocess batched inputs.
+    inputs = {
+        "encoder_text": ["The fox was sleeping.", "The lion was quiet."],
+        "decoder_text": ["The fox was awake.", "The lion was roaring."]
+    }
+    preprocessor(inputs)
+
+    # Custom vocabulary.
+    vocab = {
+        "<s>": 0,
+        "<pad>": 1,
+        "</s>": 2,
+        "Ġafter": 5,
+        "noon": 6,
+        "Ġsun": 7,
+    }
+    merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
+    merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
+
+    tokenizer = keras_nlp.models.BartTokenizer(
+        vocabulary=vocab,
+        merges=merges,
+    )
+    preprocessor = keras_nlp.models.BartPreprocessor(
+        tokenizer=tokenizer,
+        encoder_sequence_length=20,
+        decoder_sequence_length=10,
+    )
+    inputs = {
+        "encoder_text": "The fox was sleeping.",
+        "decoder_text": "The fox was awake."
+    }
+    preprocessor(inputs)
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")
+
+    # Map labeled single sentences.
+    features = {
+        "encoder_text": tf.constant(
+            ["The fox was sleeping.", "The lion was quiet."]
+        ),
+        "decoder_text": tf.constant(
+            ["The fox was awake.", "The lion was silent."]
+        )
+    }
+    labels = tf.constant(["True", "False"])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    features = {
+        "encoder_text": tf.constant(
+            ["The fox was sleeping.", "The lion was quiet."]
+        ),
+        "decoder_text": tf.constant(
+            ["The fox was awake.", "The lion was roaring."]
+        )
+    }
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        encoder_sequence_length=1024,
+        decoder_sequence_length=1024,
+        truncate="round_robin",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+
+        self.encoder_packer = MultiSegmentPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            truncate=truncate,
+            sequence_length=encoder_sequence_length,
+        )
+        self.decoder_packer = MultiSegmentPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            truncate=truncate,
+            sequence_length=decoder_sequence_length,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "encoder_sequence_length": self.encoder_packer.sequence_length,
+                "decoder_sequence_length": self.decoder_packer.sequence_length,
+                "truncate": self.encoder_packer.truncate,
+            }
+        )
+        return config
+
+    def call(self, x, y=None, sample_weight=None):
+        if not (
+            isinstance(x, dict)
+            and ["encoder_text", "decoder_text"] == list(x.keys())
+        ):
+            raise ValueError(
+                '`x` must be a dictionary containing the keys `"encoder_text"`'
+                f' and `"decoder_text"`. Received x={x}.'
+            )
+
+        encoder_text = x["encoder_text"]
+        decoder_text = x["decoder_text"]
+
+        encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)
+        decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)
+
+        if len(encoder_text) > 1 or len(decoder_text) > 1:
+            raise ValueError(
+                '`BartPreprocessor` requires both `"encoder_text"` and '
+                f'`"decoder_text"` to contain only one segment, but received '
+                f"{len(encoder_text)} and {len(decoder_text)}, respectively."
+            )
+
+        encoder_inputs = [self.tokenizer(segment) for segment in encoder_text]
+        encoder_token_ids, _ = self.encoder_packer(encoder_inputs)
+
+        decoder_inputs = [self.tokenizer(segment) for segment in decoder_text]
+        decoder_token_ids, _ = self.decoder_packer(decoder_inputs)
+
+        x = {
+            "encoder_token_ids": encoder_token_ids,
+            "encoder_padding_mask": encoder_token_ids
+            != self.tokenizer.pad_token_id,
+            "decoder_token_ids": decoder_token_ids,
+            "decoder_padding_mask": decoder_token_ids
+            != self.tokenizer.pad_token_id,
+        }
+
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    @classproperty
+    def tokenizer_cls(cls):
+        return BartTokenizer
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
+
+    @classmethod
+    def from_preset(
+        cls,
+        preset,
+        **kwargs,
+    ):
+        # Override base class's `from_preset` to handle
+        # `encoder_sequence_length` and `decoder_sequence_length`.
+        if not cls.presets:
+            raise NotImplementedError(
+                "No presets have been created for this class."
+            )
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+
+        tokenizer = cls.tokenizer_cls.from_preset(preset)
+
+        metadata = cls.presets[preset]
+        # For task model presets, the backbone config is nested.
+ if "backbone" in metadata["config"]: + backbone_config = metadata["config"]["backbone"]["config"] + else: + backbone_config = metadata["config"] + + # Use model's `max_sequence_length` if either `encoder_sequence_length` + # or `decoder_sequence_length` are unspecified; otherwise check that + # `encoder_sequence_length`/`decoder_sequence_length` are not too long. + encoder_sequence_length = kwargs.pop("encoder_sequence_length", None) + decoder_sequence_length = kwargs.pop("decoder_sequence_length", None) + max_sequence_length = backbone_config["max_sequence_length"] + + def check_sequence_length(sequence_length, name): + if sequence_length is not None: + if sequence_length > max_sequence_length: + raise ValueError( + f"`{name}` cannot be longer than `{preset}` " + f"preset's `max_sequence_length` of {max_sequence_length}. " + f"Received: {sequence_length}." + ) + return sequence_length + else: + return max_sequence_length + + encoder_sequence_length = check_sequence_length( + encoder_sequence_length, "encoder_sequence_length" + ) + decoder_sequence_length = check_sequence_length( + decoder_sequence_length, "decoder_sequence_length" + ) + + return cls( + tokenizer=tokenizer, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length, + **kwargs, + ) diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py new file mode 100644 index 0000000000..6946374e22 --- /dev/null +++ b/keras_nlp/models/bart/bart_preprocessor_test.py @@ -0,0 +1,202 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for BART preprocessor layer.""" + +import os + +import pytest +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor +from keras_nlp.models.bart.bart_tokenizer import BartTokenizer + + +class BartPreprocessorTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + vocab = { + "": 0, + "": 1, + "": 2, + "Ġair": 3, + "plane": 4, + "Ġat": 5, + "port": 6, + "Ġkoh": 7, + "li": 8, + "Ġis": 9, + "Ġthe": 10, + "Ġbest": 11, + "": 12, + } + + merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] + merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] + merges += ["pla ne"] + + self.preprocessor = BartPreprocessor( + tokenizer=BartTokenizer( + vocabulary=vocab, + merges=merges, + ), + encoder_sequence_length=10, + decoder_sequence_length=8, + ) + + def test_tokenize_strings(self): + input_data = { + "encoder_text": " airplane at airport", + "decoder_text": " kohli is the best", + } + + output = self.preprocessor(input_data) + self.assertAllEqual( + output["encoder_token_ids"], [0, 3, 4, 5, 3, 6, 2, 1, 1, 1] + ) + self.assertAllEqual( + output["encoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0] + ) + self.assertAllEqual( + output["decoder_token_ids"], [0, 7, 8, 9, 10, 11, 2, 1] + ) + self.assertAllEqual( + output["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0] + ) + + def test_tokenize_list_of_strings(self): + input_data = { + "encoder_text": [" airplane at airport"] * 4, + "decoder_text": [" kohli is the best"] * 4, + } + + output = self.preprocessor(input_data) + self.assertAllEqual( + output["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4 + ) + self.assertAllEqual( + output["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4 + ) + self.assertAllEqual( + output["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4 + ) + self.assertAllEqual( + output["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4 + ) + + def test_tokenize_labeled_batch(self): + x = { + "encoder_text": [" airplane at airport"] * 4, + "decoder_text": [" kohli is the best"] * 4, + } + y = tf.constant([1] * 4) + sw = tf.constant([1.0] * 4) + x_out, y_out, sw_out = self.preprocessor(x, y, sw) + self.assertAllEqual( + x_out["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4 + ) + self.assertAllEqual( + x_out["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4 + ) + self.assertAllEqual( + x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4 + ) + self.assertAllEqual( + x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4 + ) + self.assertAllEqual(y_out, y) + self.assertAllEqual(sw_out, sw) + + def test_tokenize_labeled_dataset(self): + x = { + "encoder_text": [" airplane at airport"] * 4, + "decoder_text": [" kohli is the best"] * 4, + } + y = tf.constant([1] * 4) + sw = tf.constant([1.0] * 4) + ds = tf.data.Dataset.from_tensor_slices((x, y, sw)) + ds = ds.map(self.preprocessor) + x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element() + self.assertAllEqual( + x_out["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4 + ) + self.assertAllEqual( + x_out["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4 + ) + self.assertAllEqual( + x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4 + ) + self.assertAllEqual( + x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4 + ) + 
self.assertAllEqual(y_out, y) + self.assertAllEqual(sw_out, sw) + + def test_error_multi_segment_input(self): + input_data = { + "encoder_text": ( + tf.constant([" airplane at airport"] * 2), + tf.constant([" airplane"] * 2), + ), + "decoder_text": ( + tf.constant([" kohli is the best"] * 2), + tf.constant([" kohli"] * 2), + ), + } + + with self.assertRaises(ValueError): + self.preprocessor(input_data) + + def test_serialization(self): + new_preprocessor = keras.utils.deserialize_keras_object( + keras.utils.serialize_keras_object(self.preprocessor) + ) + self.assertEqual( + new_preprocessor.get_config(), self.preprocessor.get_config() + ) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + @pytest.mark.large + def test_saved_model(self, save_format, filename): + input_data = { + "encoder_text": tf.constant(" airplane at airport"), + "decoder_text": tf.constant(" kohli is the best"), + } + + inputs = { + "encoder_text": keras.Input(dtype="string", shape=()), + "decoder_text": keras.Input(dtype="string", shape=()), + } + outputs = self.preprocessor(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + path = os.path.join(self.get_temp_dir(), filename) + # Don't save traces in the tf format, we check compilation elsewhere. + kwargs = {"save_traces": False} if save_format == "tf" else {} + model.save(path, save_format=save_format, **kwargs) + + restored_model = keras.models.load_model(path) + + model_output = model(input_data) + restored_model_output = restored_model(input_data) + + self.assertAllClose( + model_output, + restored_model_output, + ) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py new file mode 100644 index 0000000000..41ca5e61e5 --- /dev/null +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -0,0 +1,187 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BART Seq2Seq LM preprocessor layer.""" + +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.BartSeq2SeqLMPreprocessor") +class BartSeq2SeqLMPreprocessor(BartPreprocessor): + """BART Seq2Seq LM preprocessor. + + This layer is used as preprocessor for seq2seq tasks using the BART model. + This class subclasses `keras_nlp.models.BartPreprocessor` and keeps most of + its functionality. It has two changes from the superclass: + + 1. Sets the `y` (label) and `sample_weights` fields by shifting the + decoder input sequence one step towards the left. Both these fields are + inferred internally, and any passed values will be ignored. + 2. Drops the last token from the decoder input sequence as it does not have + a successor. + + Args: + tokenizer: A `keras_nlp.models.BartTokenizer` instance. 
+        encoder_sequence_length: The length of the packed encoder inputs.
+        decoder_sequence_length: The length of the packed decoder inputs.
+        truncate: string. The algorithm to truncate a list of batched segments
+            to fit within `sequence_length`. The value can be either
+            `round_robin` or `waterfall`:
+            - `"round_robin"`: Available space is assigned one token at a
+                time in a round-robin fashion to the inputs that still need
+                some, until the limit is reached.
+            - `"waterfall"`: The allocation of the budget is done using a
+                "waterfall" algorithm that allocates quota in a
+                left-to-right manner and fills up the buckets until we run
+                out of budget. It supports an arbitrary number of segments.
+
+    Call arguments:
+        x: A dictionary with `encoder_text` and `decoder_text` as its keys.
+            Each value in the dictionary should be a tensor of single string
+            sequences. Inputs may be batched or unbatched. Raw Python inputs
+            will be converted to tensors.
+        y: Label data. Should always be `None` as the layer generates labels
+            by shifting the decoder input sequence one step to the left.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights by shifting the padding mask one step to
+            the left.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
+        "bart_base_en"
+    )
+
+    # Preprocess unbatched inputs.
+    inputs = {
+        "encoder_text": "The fox was sleeping.",
+        "decoder_text": "The fox was awake."
+    }
+    preprocessor(inputs)
+
+    # Preprocess batched inputs.
+    inputs = {
+        "encoder_text": ["The fox was sleeping.", "The lion was quiet."],
+        "decoder_text": ["The fox was awake.", "The lion was roaring."]
+    }
+    preprocessor(inputs)
+
+    # Custom vocabulary.
+    vocab = {
+        "<s>": 0,
+        "<pad>": 1,
+        "</s>": 2,
+        "Ġafter": 5,
+        "noon": 6,
+        "Ġsun": 7,
+    }
+    merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
+    merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
+
+    tokenizer = keras_nlp.models.BartTokenizer(
+        vocabulary=vocab,
+        merges=merges,
+    )
+    preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor(
+        tokenizer=tokenizer,
+        encoder_sequence_length=20,
+        decoder_sequence_length=10,
+    )
+    inputs = {
+        "encoder_text": "The fox was sleeping.",
+        "decoder_text": "The fox was awake."
+    }
+    preprocessor(inputs)
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
+        "bart_base_en"
+    )
+
+    # Map single sentences.
+    features = {
+        "encoder_text": tf.constant(
+            ["The fox was sleeping.", "The lion was quiet."]
+        ),
+        "decoder_text": tf.constant(
+            ["The fox was awake.", "The lion was roaring."]
+        )
+    }
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        encoder_sequence_length,
+        decoder_sequence_length,
+        truncate="round_robin",
+        **kwargs
+    ):
+        # Since we truncate the last token from `decoder_token_ids`, we need
+        # to forcefully set the `decoder_sequence_length` to one greater than
+        # the value passed.
+        super().__init__(
+            tokenizer=tokenizer,
+            encoder_sequence_length=encoder_sequence_length,
+            decoder_sequence_length=decoder_sequence_length + 1,
+            truncate=truncate,
+            **kwargs
+        )
+
+        # Maintain a private copy of `decoder_sequence_length` for config
+        # purposes.
+        self._decoder_sequence_length = decoder_sequence_length
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "encoder_sequence_length": self.encoder_packer.sequence_length,
+                "decoder_sequence_length": self._decoder_sequence_length,
+                "truncate": self.encoder_packer.truncate,
+            }
+        )
+        return config
+
+    def call(self, x, y=None, sample_weight=None):
+        if y is not None or sample_weight is not None:
+            logging.warning(
+                "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
+                "from the provided input data, i.e., `x`. However, non-`None` "
+                "values have been passed for `y` or `sample_weight` or both. "
+                "These values will be ignored."
+            )
+
+        x = super().call(x)
+        decoder_token_ids = x.pop("decoder_token_ids")
+        decoder_padding_mask = x.pop("decoder_padding_mask")
+
+        # The last token does not have a next token. Hence, we truncate it.
+        x = {
+            **x,
+            "decoder_token_ids": decoder_token_ids[..., :-1],
+            "decoder_padding_mask": decoder_padding_mask[..., :-1],
+        }
+        # Target `y` will be the decoder input sequence shifted one step to
+        # the left (i.e., the next token).
+        y = decoder_token_ids[..., 1:]
+        sample_weight = decoder_padding_mask[..., 1:]
+        return pack_x_y_sample_weight(x, y, sample_weight)
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
new file mode 100644
index 0000000000..dd6ecfde91
--- /dev/null
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
@@ -0,0 +1,161 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
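The label and sample-weight shifting done in `call` above is easiest to see on concrete values. The sketch below reuses the toy vocabulary from the earlier `BartPreprocessor` sketch; the expected values mirror the assertions in the test file that follows. Note that the decoder side is packed internally to `decoder_sequence_length + 1` before the last position is dropped.

```python
# A sketch of the shift performed by BartSeq2SeqLMPreprocessor, using the
# same toy vocabulary/merges as the earlier BartPreprocessor sketch.
from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
    BartSeq2SeqLMPreprocessor,
)
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer

vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "Ġair": 3, "plane": 4, "Ġat": 5,
         "port": 6, "Ġkoh": 7, "li": 8, "Ġis": 9, "Ġthe": 10, "Ġbest": 11,
         "<mask>": 12}
merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"]
merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
merges += ["pla ne"]

preprocessor = BartSeq2SeqLMPreprocessor(
    tokenizer=BartTokenizer(vocabulary=vocab, merges=merges),
    encoder_sequence_length=10,
    decoder_sequence_length=8,
)
x, y, sample_weight = preprocessor(
    {
        "encoder_text": " airplane at airport",
        "decoder_text": " kohli is the best",
    }
)
# The decoder is packed to length 8 + 1 = 9 -> [0, 7, 8, 9, 10, 11, 2, 1, 1],
# then split into inputs and targets:
# x["decoder_token_ids"] -> [0, 7, 8, 9, 10, 11, 2, 1]  (last position dropped)
# y                      -> [7, 8, 9, 10, 11, 2, 1, 1]  (shifted one step left)
# sample_weight          -> [1, 1, 1, 1, 1, 1, 0, 0]    (shifted padding mask)
```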
+ +"""Tests for BART preprocessor layer.""" + +import os + +import pytest +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import ( + BartSeq2SeqLMPreprocessor, +) +from keras_nlp.models.bart.bart_tokenizer import BartTokenizer + + +class BartSeq2SeqLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + vocab = { + "": 0, + "": 1, + "": 2, + "Ġair": 3, + "plane": 4, + "Ġat": 5, + "port": 6, + "Ġkoh": 7, + "li": 8, + "Ġis": 9, + "Ġthe": 10, + "Ġbest": 11, + "": 12, + } + + merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] + merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] + merges += ["pla ne"] + + self.preprocessor = BartSeq2SeqLMPreprocessor( + tokenizer=BartTokenizer( + vocabulary=vocab, + merges=merges, + ), + encoder_sequence_length=10, + decoder_sequence_length=8, + ) + + def test_tokenize_strings(self): + input_data = { + "encoder_text": " airplane at airport", + "decoder_text": " kohli is the best", + } + + x_out, y_out, sw_out = self.preprocessor(input_data) + self.assertAllEqual( + x_out["encoder_token_ids"], [0, 3, 4, 5, 3, 6, 2, 1, 1, 1] + ) + self.assertAllEqual( + x_out["encoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0] + ) + self.assertAllEqual( + x_out["decoder_token_ids"], [0, 7, 8, 9, 10, 11, 2, 1] + ) + self.assertAllEqual( + x_out["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0] + ) + self.assertAllEqual(y_out, [7, 8, 9, 10, 11, 2, 1, 1]) + self.assertAllEqual(sw_out, [1, 1, 1, 1, 1, 1, 0, 0]) + + def test_tokenize_list_of_strings(self): + input_data = { + "encoder_text": [" airplane at airport"] * 4, + "decoder_text": [" kohli is the best"] * 4, + } + + x_out, y_out, sw_out = self.preprocessor(input_data) + self.assertAllEqual( + x_out["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4 + ) + self.assertAllEqual( + x_out["encoder_padding_mask"], + [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4, + ) + self.assertAllEqual( + x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4 + ) + self.assertAllEqual( + x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4 + ) + self.assertAllEqual(y_out, [[7, 8, 9, 10, 11, 2, 1, 1]] * 4) + self.assertAllEqual(sw_out, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4) + + def test_error_multi_segment_input(self): + input_data = { + "encoder_text": ( + tf.constant([" airplane at airport"] * 2), + tf.constant([" airplane"] * 2), + ), + "decoder_text": ( + tf.constant([" kohli is the best"] * 2), + tf.constant([" kohli"] * 2), + ), + } + + with self.assertRaises(ValueError): + self.preprocessor(input_data) + + def test_serialization(self): + new_preprocessor = keras.utils.deserialize_keras_object( + keras.utils.serialize_keras_object(self.preprocessor) + ) + self.assertEqual( + new_preprocessor.get_config(), self.preprocessor.get_config() + ) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + @pytest.mark.large + def test_saved_model(self, save_format, filename): + input_data = { + "encoder_text": tf.constant(" airplane at airport"), + "decoder_text": tf.constant(" kohli is the best"), + } + + inputs = { + "encoder_text": keras.Input(dtype="string", shape=()), + "decoder_text": keras.Input(dtype="string", shape=()), + } + outputs = self.preprocessor(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + path = 
os.path.join(self.get_temp_dir(), filename) + # Don't save traces in the tf format, we check compilation elsewhere. + kwargs = {"save_traces": False} if save_format == "tf" else {} + model.save(path, save_format=save_format, **kwargs) + + restored_model = keras.models.load_model(path) + + model_output = model(input_data) + restored_model_output = restored_model(input_data) + + self.assertAllClose( + model_output, + restored_model_output, + )
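For end-to-end usage, the new preprocessor plugs into a `tf.data` pipeline and yields `(x, y, sample_weight)` tuples ready for `keras.Model.fit`. A minimal sketch, assuming the `bart_base_en` preset referenced in the docstrings is available for download; the sequence lengths here are illustrative:

```python
import tensorflow as tf

import keras_nlp

# Sequence lengths are illustrative; they must not exceed the preset's
# `max_sequence_length` (1024 for "bart_base_en"), which `from_preset`
# enforces via the checks added above.
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=512,
    decoder_sequence_length=128,
)
features = {
    "encoder_text": ["The fox was sleeping.", "The lion was quiet."],
    "decoder_text": ["The fox was awake.", "The lion was roaring."],
}
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Each element is an (x, y, sample_weight) tuple, where x holds
# "encoder_token_ids", "encoder_padding_mask", "decoder_token_ids" and
# "decoder_padding_mask", and y/sample_weight are the decoder token ids and
# padding mask shifted one step to the left.
```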