diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index a479c21004..2a425245e7 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -21,6 +21,10 @@
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.bart.bart_backbone import BartBackbone
+from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
+from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
+ BartSeq2SeqLMPreprocessor,
+)
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py
new file mode 100644
index 0000000000..f1da3e4fba
--- /dev/null
+++ b/keras_nlp/models/bart/bart_preprocessor.py
@@ -0,0 +1,292 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BART preprocessor layer."""
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
+from keras_nlp.models.bart.bart_presets import backbone_presets
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+ convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.BartPreprocessor")
+class BartPreprocessor(Preprocessor):
+ """A BART preprocessing layer which tokenizes and packs inputs.
+
+ This preprocessing layer will do three things:
+
+ 1. Tokenize both encoder inputs and decoder inputs using the `tokenizer`.
+ Both inputs can contain only one segment.
+    2. Add the appropriate special tokens - `"<s>"`, `"</s>"` and `"<pad>"`.
+ 3. Construct a dictionary with keys `"encoder_token_ids"`,
+ `"encoder_padding_mask"`, `"decoder_token_ids"`, `"decoder_padding_mask"`
+ that can be passed directly to a BART model.
+
+ Args:
+ tokenizer: A `keras_nlp.models.BartTokenizer` instance.
+ encoder_sequence_length: The length of the packed encoder inputs.
+ decoder_sequence_length: The length of the packed decoder inputs.
+ truncate: string. The algorithm to truncate a list of batched segments
+ to fit within `sequence_length`. The value can be either
+ `round_robin` or `waterfall`:
+ - `"round_robin"`: Available space is assigned one token at a
+ time in a round-robin fashion to the inputs that still need
+ some, until the limit is reached.
+ - `"waterfall"`: The allocation of the budget is done using a
+ "waterfall" algorithm that allocates quota in a
+ left-to-right manner and fills up the buckets until we run
+ out of budget. It supports an arbitrary number of segments.
+
+ Call arguments:
+ x: A dictionary with `encoder_text` and `decoder_text` as its keys.
+ Each value in the dictionary should be a tensor of single string
+ sequences. Inputs may be batched or unbatched. Raw python inputs
+ will be converted to tensors.
+ y: Any label data. Will be passed through unaltered.
+ sample_weight: Any label weight data. Will be passed through unaltered.
+
+ Examples:
+
+    Directly calling the layer on data.
+ ```python
+ preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")
+
+ # Preprocess unbatched inputs.
+ inputs = {
+ "encoder_text": "The fox was sleeping.",
+ "decoder_text": "The fox was awake."
+ }
+ preprocessor(inputs)
+
+ # Preprocess batched inputs.
+ inputs = {
+ "encoder_text": ["The fox was sleeping.", "The lion was quiet."],
+ "decoder_text": ["The fox was awake.", "The lion was roaring."]
+ }
+ preprocessor(inputs)
+
+ # Custom vocabulary.
+ vocab = {
+        "<s>": 0,
+        "<pad>": 1,
+        "</s>": 2,
+ "Ġafter": 5,
+ "noon": 6,
+ "Ġsun": 7,
+ }
+ merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
+ merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
+
+ tokenizer = keras_nlp.models.BartTokenizer(
+ vocabulary=vocab,
+ merges=merges,
+ )
+ preprocessor = keras_nlp.models.BartPreprocessor(
+ tokenizer=tokenizer,
+ encoder_sequence_length=20,
+ decoder_sequence_length=10,
+ )
+ inputs = {
+ "encoder_text": "The fox was sleeping.",
+ "decoder_text": "The fox was awake."
+ }
+ preprocessor(inputs)
+ ```
+
+ Mapping with `tf.data.Dataset`.
+ ```python
+ preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")
+
+ # Map labeled single sentences.
+ features = {
+ "encoder_text": tf.constant(
+ ["The fox was sleeping.", "The lion was quiet."]
+ ),
+ "decoder_text": tf.constant(
+ ["The fox was awake.", "The lion was silent."]
+ )
+ }
+ labels = tf.constant(["True", "False"])
+ ds = tf.data.Dataset.from_tensor_slices((features, labels))
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map unlabeled single sentences.
+ features = {
+ "encoder_text": tf.constant(
+ ["The fox was sleeping.", "The lion was quiet."]
+ ),
+ "decoder_text": tf.constant(
+ ["The fox was awake.", "The lion was roaring."]
+ )
+ }
+ ds = tf.data.Dataset.from_tensor_slices(features)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+ ```
+ """
+
+ def __init__(
+ self,
+ tokenizer,
+ encoder_sequence_length=1024,
+ decoder_sequence_length=1024,
+ truncate="round_robin",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.tokenizer = tokenizer
+
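+        # Both packers wrap the tokenized input with the start (`<s>`) and
+        # end (`</s>`) tokens, then pad to a fixed length. The encoder and
+        # decoder inputs can be packed to different sequence lengths.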
+ self.encoder_packer = MultiSegmentPacker(
+ start_value=self.tokenizer.start_token_id,
+ end_value=self.tokenizer.end_token_id,
+ pad_value=self.tokenizer.pad_token_id,
+ truncate=truncate,
+ sequence_length=encoder_sequence_length,
+ )
+ self.decoder_packer = MultiSegmentPacker(
+ start_value=self.tokenizer.start_token_id,
+ end_value=self.tokenizer.end_token_id,
+ pad_value=self.tokenizer.pad_token_id,
+ truncate=truncate,
+ sequence_length=decoder_sequence_length,
+ )
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "encoder_sequence_length": self.encoder_packer.sequence_length,
+ "decoder_sequence_length": self.decoder_packer.sequence_length,
+ "truncate": self.encoder_packer.truncate,
+ }
+ )
+ return config
+
+ def call(self, x, y=None, sample_weight=None):
+ if not (
+ isinstance(x, dict)
+ and ["encoder_text", "decoder_text"] == list(x.keys())
+ ):
+ raise ValueError(
+ '`x` must be a dictionary, containing the keys `"encoder_text"`'
+ f' and `"decoder_text"`. Received x={x}.'
+ )
+
+ encoder_text = x["encoder_text"]
+ decoder_text = x["decoder_text"]
+
+ encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)
+ decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)
+
+ if len(encoder_text) > 1 or len(decoder_text) > 1:
+ raise ValueError(
+                '`BartPreprocessor` requires both `"encoder_text"` and '
+ f'`"decoder_text"` to contain only one segment, but received '
+ f"{len(encoder_text)} and {len(decoder_text)}, respectively."
+ )
+
+ encoder_inputs = [self.tokenizer(segment) for segment in encoder_text]
+ encoder_token_ids, _ = self.encoder_packer(encoder_inputs)
+
+ decoder_inputs = [self.tokenizer(segment) for segment in decoder_text]
+ decoder_token_ids, _ = self.decoder_packer(decoder_inputs)
+
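+        # Every non-pad token is marked as valid in the padding masks.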
+ x = {
+ "encoder_token_ids": encoder_token_ids,
+ "encoder_padding_mask": encoder_token_ids
+ != self.tokenizer.pad_token_id,
+ "decoder_token_ids": decoder_token_ids,
+ "decoder_padding_mask": decoder_token_ids
+ != self.tokenizer.pad_token_id,
+ }
+
+ return pack_x_y_sample_weight(x, y, sample_weight)
+
+ @classproperty
+ def tokenizer_cls(cls):
+ return BartTokenizer
+
+ @classproperty
+ def presets(cls):
+ return copy.deepcopy(backbone_presets)
+
+ @classmethod
+ def from_preset(
+ cls,
+ preset,
+ **kwargs,
+ ):
+ # Override base class's `from_preset` to handle `encoder_sequence_length`
+ # and `decoder_sequence_length`.
+ if not cls.presets:
+ raise NotImplementedError(
+ "No presets have been created for this class."
+ )
+ if preset not in cls.presets:
+ raise ValueError(
+ "`preset` must be one of "
+ f"""{", ".join(cls.presets)}. Received: {preset}."""
+ )
+
+ tokenizer = cls.tokenizer_cls.from_preset(preset)
+
+ metadata = cls.presets[preset]
+ # For task model presets, the backbone config is nested.
+ if "backbone" in metadata["config"]:
+ backbone_config = metadata["config"]["backbone"]["config"]
+ else:
+ backbone_config = metadata["config"]
+
+ # Use model's `max_sequence_length` if either `encoder_sequence_length`
+ # or `decoder_sequence_length` are unspecified; otherwise check that
+ # `encoder_sequence_length`/`decoder_sequence_length` are not too long.
+ encoder_sequence_length = kwargs.pop("encoder_sequence_length", None)
+ decoder_sequence_length = kwargs.pop("decoder_sequence_length", None)
+ max_sequence_length = backbone_config["max_sequence_length"]
+
+ def check_sequence_length(sequence_length, name):
+ if sequence_length is not None:
+ if sequence_length > max_sequence_length:
+ raise ValueError(
+ f"`{name}` cannot be longer than `{preset}` "
+ f"preset's `max_sequence_length` of {max_sequence_length}. "
+ f"Received: {sequence_length}."
+ )
+ return sequence_length
+ else:
+ return max_sequence_length
+
+ encoder_sequence_length = check_sequence_length(
+ encoder_sequence_length, "encoder_sequence_length"
+ )
+ decoder_sequence_length = check_sequence_length(
+ decoder_sequence_length, "decoder_sequence_length"
+ )
+
+ return cls(
+ tokenizer=tokenizer,
+ encoder_sequence_length=encoder_sequence_length,
+ decoder_sequence_length=decoder_sequence_length,
+ **kwargs,
+ )
diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py
new file mode 100644
index 0000000000..6946374e22
--- /dev/null
+++ b/keras_nlp/models/bart/bart_preprocessor_test.py
@@ -0,0 +1,204 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for BART preprocessor layer."""
+
+import os
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+
+
+class BartPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
+ def setUp(self):
+ vocab = {
+            "<s>": 0,
+            "<pad>": 1,
+            "</s>": 2,
+ "Ġair": 3,
+ "plane": 4,
+ "Ġat": 5,
+ "port": 6,
+ "Ġkoh": 7,
+ "li": 8,
+ "Ġis": 9,
+ "Ġthe": 10,
+ "Ġbest": 11,
+            "<mask>": 12,
+ }
+
+ merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"]
+ merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
+ merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
+ merges += ["pla ne"]
+
+ self.preprocessor = BartPreprocessor(
+ tokenizer=BartTokenizer(
+ vocabulary=vocab,
+ merges=merges,
+ ),
+ encoder_sequence_length=10,
+ decoder_sequence_length=8,
+ )
+
+ def test_tokenize_strings(self):
+ input_data = {
+ "encoder_text": " airplane at airport",
+ "decoder_text": " kohli is the best",
+ }
+
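+        # " airplane at airport" tokenizes to [3, 4, 5, 3, 6]; the packer
+        # adds <s> (0) and </s> (2), then pads with <pad> (1) to length 10.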
+ output = self.preprocessor(input_data)
+ self.assertAllEqual(
+ output["encoder_token_ids"], [0, 3, 4, 5, 3, 6, 2, 1, 1, 1]
+ )
+ self.assertAllEqual(
+ output["encoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
+ )
+ self.assertAllEqual(
+ output["decoder_token_ids"], [0, 7, 8, 9, 10, 11, 2, 1]
+ )
+ self.assertAllEqual(
+ output["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]
+ )
+
+ def test_tokenize_list_of_strings(self):
+ input_data = {
+ "encoder_text": [" airplane at airport"] * 4,
+ "decoder_text": [" kohli is the best"] * 4,
+ }
+
+ output = self.preprocessor(input_data)
+ self.assertAllEqual(
+ output["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4
+ )
+ self.assertAllEqual(
+ output["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
+ )
+ self.assertAllEqual(
+ output["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+ )
+ self.assertAllEqual(
+ output["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+ )
+
+ def test_tokenize_labeled_batch(self):
+ x = {
+ "encoder_text": [" airplane at airport"] * 4,
+ "decoder_text": [" kohli is the best"] * 4,
+ }
+ y = tf.constant([1] * 4)
+ sw = tf.constant([1.0] * 4)
+ x_out, y_out, sw_out = self.preprocessor(x, y, sw)
+ self.assertAllEqual(
+ x_out["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4
+ )
+ self.assertAllEqual(
+ x_out["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
+ )
+ self.assertAllEqual(
+ x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+ )
+ self.assertAllEqual(
+ x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+ )
+ self.assertAllEqual(y_out, y)
+ self.assertAllEqual(sw_out, sw)
+
+ def test_tokenize_labeled_dataset(self):
+ x = {
+ "encoder_text": [" airplane at airport"] * 4,
+ "decoder_text": [" kohli is the best"] * 4,
+ }
+ y = tf.constant([1] * 4)
+ sw = tf.constant([1.0] * 4)
+ ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
+ ds = ds.map(self.preprocessor)
+ x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element()
+ self.assertAllEqual(
+ x_out["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4
+ )
+ self.assertAllEqual(
+ x_out["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
+ )
+ self.assertAllEqual(
+ x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+ )
+ self.assertAllEqual(
+ x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+ )
+ self.assertAllEqual(y_out, y)
+ self.assertAllEqual(sw_out, sw)
+
+ def test_error_multi_segment_input(self):
+ input_data = {
+ "encoder_text": (
+ tf.constant([" airplane at airport"] * 2),
+ tf.constant([" airplane"] * 2),
+ ),
+ "decoder_text": (
+ tf.constant([" kohli is the best"] * 2),
+ tf.constant([" kohli"] * 2),
+ ),
+ }
+
+ with self.assertRaises(ValueError):
+ self.preprocessor(input_data)
+
+ def test_serialization(self):
+ new_preprocessor = keras.utils.deserialize_keras_object(
+ keras.utils.serialize_keras_object(self.preprocessor)
+ )
+ self.assertEqual(
+ new_preprocessor.get_config(), self.preprocessor.get_config()
+ )
+
+ @parameterized.named_parameters(
+ ("tf_format", "tf", "model"),
+ ("keras_format", "keras_v3", "model.keras"),
+ )
+ @pytest.mark.large
+ def test_saved_model(self, save_format, filename):
+ input_data = {
+ "encoder_text": tf.constant(" airplane at airport"),
+ "decoder_text": tf.constant(" kohli is the best"),
+ }
+
+ inputs = {
+ "encoder_text": keras.Input(dtype="string", shape=()),
+ "decoder_text": keras.Input(dtype="string", shape=()),
+ }
+ outputs = self.preprocessor(inputs)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+
+ path = os.path.join(self.get_temp_dir(), filename)
+ # Don't save traces in the tf format, we check compilation elsewhere.
+ kwargs = {"save_traces": False} if save_format == "tf" else {}
+ model.save(path, save_format=save_format, **kwargs)
+
+ restored_model = keras.models.load_model(path)
+
+ model_output = model(input_data)
+ restored_model_output = restored_model(input_data)
+
+ self.assertAllClose(
+ model_output,
+ restored_model_output,
+ )
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
new file mode 100644
index 0000000000..41ca5e61e5
--- /dev/null
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
@@ -0,0 +1,197 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BART Seq2Seq LM preprocessor layer."""
+
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.BartSeq2SeqLMPreprocessor")
+class BartSeq2SeqLMPreprocessor(BartPreprocessor):
+ """BART Seq2Seq LM preprocessor.
+
+ This layer is used as preprocessor for seq2seq tasks using the BART model.
+ This class subclasses `keras_nlp.models.BartPreprocessor` and keeps most of
+ its functionality. It has two changes from the superclass:
+
+    1. Sets the `y` (label) and `sample_weight` fields by shifting the
+ decoder input sequence one step towards the left. Both these fields are
+ inferred internally, and any passed values will be ignored.
+ 2. Drops the last token from the decoder input sequence as it does not have
+ a successor.
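+
+    For example, if the packed decoder sequence is `[<s>, t, </s>, <pad>]`
+    with padding mask `[1, 1, 1, 0]`, the layer will return
+    `decoder_token_ids=[<s>, t, </s>]`, `y=[t, </s>, <pad>]` and
+    `sample_weight=[1, 1, 0]`.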
+
+ Args:
+ tokenizer: A `keras_nlp.models.BartTokenizer` instance.
+ encoder_sequence_length: The length of the packed encoder inputs.
+ decoder_sequence_length: The length of the packed decoder inputs.
+ truncate: string. The algorithm to truncate a list of batched segments
+ to fit within `sequence_length`. The value can be either
+ `round_robin` or `waterfall`:
+ - `"round_robin"`: Available space is assigned one token at a
+ time in a round-robin fashion to the inputs that still need
+ some, until the limit is reached.
+ - `"waterfall"`: The allocation of the budget is done using a
+ "waterfall" algorithm that allocates quota in a
+ left-to-right manner and fills up the buckets until we run
+ out of budget. It supports an arbitrary number of segments.
+
+ Call arguments:
+ x: A dictionary with `encoder_text` and `decoder_text` as its keys.
+ Each value in the dictionary should be a tensor of single string
+ sequences. Inputs may be batched or unbatched. Raw python inputs
+ will be converted to tensors.
+ y: Label data. Should always be `None` as the layer generates labels by
+ shifting the decoder input sequence one step to the left.
+ sample_weight: Label weights. Should always be `None` as the layer
+ generates label weights by shifting the padding mask one step to the
+ left.
+
+ Examples:
+
+    Directly calling the layer on data.
+ ```python
+    preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
+        "bart_base_en"
+    )
+
+ # Preprocess unbatched inputs.
+ inputs = {
+ "encoder_text": "The fox was sleeping.",
+ "decoder_text": "The fox was awake."
+ }
+ preprocessor(inputs)
+
+ # Preprocess batched inputs.
+ inputs = {
+ "encoder_text": ["The fox was sleeping.", "The lion was quiet."],
+ "decoder_text": ["The fox was awake.", "The lion was roaring."]
+ }
+ preprocessor(inputs)
+
+ # Custom vocabulary.
+ vocab = {
+        "<s>": 0,
+        "<pad>": 1,
+        "</s>": 2,
+ "Ġafter": 5,
+ "noon": 6,
+ "Ġsun": 7,
+ }
+ merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
+ merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
+
+ tokenizer = keras_nlp.models.BartTokenizer(
+ vocabulary=vocab,
+ merges=merges,
+ )
+    preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor(
+ tokenizer=tokenizer,
+ encoder_sequence_length=20,
+ decoder_sequence_length=10,
+ )
+ inputs = {
+ "encoder_text": "The fox was sleeping.",
+ "decoder_text": "The fox was awake."
+ }
+ preprocessor(inputs)
+ ```
+
+ Mapping with `tf.data.Dataset`.
+ ```python
+    preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
+        "bart_base_en"
+    )
+
+ # Map single sentences.
+ features = {
+ "encoder_text": tf.constant(
+ ["The fox was sleeping.", "The lion was quiet."]
+ ),
+ "decoder_text": tf.constant(
+ ["The fox was awake.", "The lion was roaring."]
+ )
+ }
+ ds = tf.data.Dataset.from_tensor_slices(features)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+ ```
+ """
+
+ def __init__(
+ self,
+ tokenizer,
+ encoder_sequence_length,
+ decoder_sequence_length,
+ truncate="round_robin",
+ **kwargs
+ ):
+ # Since we truncate the last token from `decoder_token_ids`, we need to
+ # forcefully set the `decoder_sequence_length` to one greater than the
+ # value passed.
+ super().__init__(
+ tokenizer=tokenizer,
+ encoder_sequence_length=encoder_sequence_length,
+ decoder_sequence_length=decoder_sequence_length + 1,
+ truncate=truncate,
+ **kwargs
+ )
+
+ # Maintain a private copy of `decoder_sequence_length` for config
+ # purposes.
+ self._decoder_sequence_length = decoder_sequence_length
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "encoder_sequence_length": self.encoder_packer.sequence_length,
+ "decoder_sequence_length": self._decoder_sequence_length,
+ "truncate": self.encoder_packer.truncate,
+ }
+ )
+ return config
+
+ def call(self, x, y=None, sample_weight=None):
+ if y is not None or sample_weight is not None:
+ logging.warning(
+ "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
+                "from the provided input data, i.e., `x`. However, non-`None` "
+                "values have been passed for `y` or `sample_weight` or both. "
+ "These values will be ignored."
+ )
+
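+        # Tokenize and pack both inputs using the superclass preprocessor.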
+ x = super().call(x)
+ decoder_token_ids = x.pop("decoder_token_ids")
+ decoder_padding_mask = x.pop("decoder_padding_mask")
+
+ # The last token does not have a next token. Hence, we truncate it.
+ x = {
+ **x,
+ "decoder_token_ids": decoder_token_ids[..., :-1],
+ "decoder_padding_mask": decoder_padding_mask[..., :-1],
+ }
+ # Target `y` will be the decoder input sequence shifted one step to the
+ # left (i.e., the next token).
+ y = decoder_token_ids[..., 1:]
+ sample_weight = decoder_padding_mask[..., 1:]
+ return pack_x_y_sample_weight(x, y, sample_weight)
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
new file mode 100644
index 0000000000..dd6ecfde91
--- /dev/null
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
@@ -0,0 +1,163 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for BART Seq2Seq LM preprocessor layer."""
+
+import os
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
+ BartSeq2SeqLMPreprocessor,
+)
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+
+
+class BartSeq2SeqLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
+ def setUp(self):
+ vocab = {
+            "<s>": 0,
+            "<pad>": 1,
+            "</s>": 2,
+ "Ġair": 3,
+ "plane": 4,
+ "Ġat": 5,
+ "port": 6,
+ "Ġkoh": 7,
+ "li": 8,
+ "Ġis": 9,
+ "Ġthe": 10,
+ "Ġbest": 11,
+            "<mask>": 12,
+ }
+
+ merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"]
+ merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
+ merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
+ merges += ["pla ne"]
+
+ self.preprocessor = BartSeq2SeqLMPreprocessor(
+ tokenizer=BartTokenizer(
+ vocabulary=vocab,
+ merges=merges,
+ ),
+ encoder_sequence_length=10,
+ decoder_sequence_length=8,
+ )
+
+ def test_tokenize_strings(self):
+ input_data = {
+ "encoder_text": " airplane at airport",
+ "decoder_text": " kohli is the best",
+ }
+
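+        # The decoder sequence is packed to length 9 internally, then the
+        # last token is dropped; `y` is the decoder sequence shifted left.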
+ x_out, y_out, sw_out = self.preprocessor(input_data)
+ self.assertAllEqual(
+ x_out["encoder_token_ids"], [0, 3, 4, 5, 3, 6, 2, 1, 1, 1]
+ )
+ self.assertAllEqual(
+ x_out["encoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
+ )
+ self.assertAllEqual(
+ x_out["decoder_token_ids"], [0, 7, 8, 9, 10, 11, 2, 1]
+ )
+ self.assertAllEqual(
+ x_out["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]
+ )
+ self.assertAllEqual(y_out, [7, 8, 9, 10, 11, 2, 1, 1])
+ self.assertAllEqual(sw_out, [1, 1, 1, 1, 1, 1, 0, 0])
+
+ def test_tokenize_list_of_strings(self):
+ input_data = {
+ "encoder_text": [" airplane at airport"] * 4,
+ "decoder_text": [" kohli is the best"] * 4,
+ }
+
+ x_out, y_out, sw_out = self.preprocessor(input_data)
+ self.assertAllEqual(
+ x_out["encoder_token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1]] * 4
+ )
+ self.assertAllEqual(
+ x_out["encoder_padding_mask"],
+ [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4,
+ )
+ self.assertAllEqual(
+ x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+ )
+ self.assertAllEqual(
+ x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+ )
+ self.assertAllEqual(y_out, [[7, 8, 9, 10, 11, 2, 1, 1]] * 4)
+ self.assertAllEqual(sw_out, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
+
+ def test_error_multi_segment_input(self):
+ input_data = {
+ "encoder_text": (
+ tf.constant([" airplane at airport"] * 2),
+ tf.constant([" airplane"] * 2),
+ ),
+ "decoder_text": (
+ tf.constant([" kohli is the best"] * 2),
+ tf.constant([" kohli"] * 2),
+ ),
+ }
+
+ with self.assertRaises(ValueError):
+ self.preprocessor(input_data)
+
+ def test_serialization(self):
+ new_preprocessor = keras.utils.deserialize_keras_object(
+ keras.utils.serialize_keras_object(self.preprocessor)
+ )
+ self.assertEqual(
+ new_preprocessor.get_config(), self.preprocessor.get_config()
+ )
+
+ @parameterized.named_parameters(
+ ("tf_format", "tf", "model"),
+ ("keras_format", "keras_v3", "model.keras"),
+ )
+ @pytest.mark.large
+ def test_saved_model(self, save_format, filename):
+ input_data = {
+ "encoder_text": tf.constant(" airplane at airport"),
+ "decoder_text": tf.constant(" kohli is the best"),
+ }
+
+ inputs = {
+ "encoder_text": keras.Input(dtype="string", shape=()),
+ "decoder_text": keras.Input(dtype="string", shape=()),
+ }
+ outputs = self.preprocessor(inputs)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+
+ path = os.path.join(self.get_temp_dir(), filename)
+ # Don't save traces in the tf format, we check compilation elsewhere.
+ kwargs = {"save_traces": False} if save_format == "tf" else {}
+ model.save(path, save_format=save_format, **kwargs)
+
+ restored_model = keras.models.load_model(path)
+
+ model_output = model(input_data)
+ restored_model_output = restored_model(input_data)
+
+ self.assertAllClose(
+ model_output,
+ restored_model_output,
+ )