253 changes: 253 additions & 0 deletions keras_nlp/models/bart/bart_preprocessor.py
@@ -0,0 +1,253 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BART preprocessor layer."""

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.bart.bart_presets import backbone_presets
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.BartPreprocessor")
class BartPreprocessor(Preprocessor):
"""A BART preprocessing layer which tokenizes and packs inputs.

This preprocessing layer will do three things:

- Tokenize both encoder inputs and decoder inputs using the `tokenizer`.
Both inputs can contain any number of segments.
- Pack the inputs together using `keras_nlp.layers.MultiSegmentPacker` with
the appropriate special tokens - `"<s>"`, `"</s>"` and `"<pad>"`.
- Construct a dictionary with keys `"encoder_token_ids"`,
`"encoder_padding_mask"`, `"decoder_token_ids"`, `"decoder_padding_mask"`
that can be passed directly to a BART model.

This layer can be used directly with `tf.data.Dataset.map` to preprocess
string data in the `(x, y, sample_weight)` format used by
`keras.Model.fit`.

The call method of this layer accepts three arguments, `x`, `y`, and
`sample_weight`. `x` should be python dictionary, having "encoder_inputs"
and "decoder_inputs" as its keys. Each value in the dictionary can be a
python string or tensor representing a single segment, a list of python
strings representing a batch of single segments, or a list of tensors
representing multiple segments to be packed together. `y` and `sample_weight`
are both optional, can have any format, and will be passed through unaltered.

Args:
tokenizer: A `keras_nlp.models.BartTokenizer` instance.
sequence_length: The length of the packed inputs.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Examples:

[Review comment, Member]: Let's rework all of these pull requests to match the style here #843

    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")

    # Tokenize and pack a single sentence.
    inputs = {
        "encoder_inputs": "The fox was sleeping.",
        "decoder_inputs": "The fox was awake."
    }
    preprocessor(inputs)

[Review comment, Member]: Open question... Should we call this "encoder_text" to better accommodate "encoder_audio" for Whisper? Or will it be simpler to have the same names everywhere? I somewhat like the self-documenting property of saying this is text input.
[Reply, abheesht17 (Collaborator, Author), Mar 30, 2023]: Yep, "encoder_text" and "decoder_text" sound good to me!

    # Same output.
    inputs = {
        "encoder_inputs": tf.constant("The fox was sleeping."),
        "decoder_inputs": tf.constant("The fox was awake.")
    }
    preprocessor(inputs)

    # Tokenize a batch of single sentences.
    inputs = {
        "encoder_inputs": ["The fox was sleeping.", "The lion was quiet."],
        "decoder_inputs": ["The fox was awake.", "The lion was roaring."]
    }
    preprocessor(inputs)
    # Same output.
    inputs = {
        "encoder_inputs": tf.constant(
            ["The fox was sleeping.", "The lion was quiet."]
        ),
        "decoder_inputs": tf.constant(
            ["The fox was awake.", "The lion was roaring."]
        )
    }
    preprocessor(inputs)

    # Tokenize and pack a sentence pair.
    inputs = {
        "encoder_inputs": (
            tf.constant("The fox was sleeping."),
            tf.constant("The lion was quiet.")
        ),
        "decoder_inputs": (
            tf.constant("The fox was awake."),
            tf.constant("The lion was roaring.")
        )
    }
    preprocessor(inputs)

[Review comment, @mattdangerw (Member), Mar 28, 2023]: Looking over this more, let's keep it simple on the first attempt, and have no support for multiple segments in the base preprocessor layer for now. This will fit with the GPT2 code. IMO this is still just too complicated, and I'm not sure of the use case. For classification, we can support multiple segments, but I don't see a huge need for multiple segments with separate encoder and decoder inputs. Do we have a clear use case there we want to support? If not, let's land this with the simpler feature set.

    # Map a dataset to preprocess a single sentence.
    features = {
        "encoder_inputs": tf.constant(
            ["The fox was sleeping.", "The lion was quiet."]
        ),
        "decoder_inputs": tf.constant(
            ["The fox was awake.", "The lion was roaring."]
        )
    }
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess sentence pairs.
    features = {
        "encoder_inputs": (
            tf.constant(["The fox was sleeping.", "The lion was quiet."]),
            tf.constant(["It wanted to get up.", "It wanted to roar."]),
        ),
        "decoder_inputs": (
            tf.constant(["The fox was sleeping.", "The lion was quiet."]),
            tf.constant(["It wanted to get up.", "It wanted to roar."]),
        ),
    }
    labels = tf.constant([0, 1])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Alternatively, you can create a preprocessor from your own vocabulary.
    # The usage is exactly the same as above.
    vocab = {
        "<s>": 0,
        "<pad>": 1,
        "</s>": 2,
        "Ġafter": 5,
        "noon": 6,
        "Ġsun": 7,
    }
    merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
    merges += ["Ġsu n", "Ġaf t", "Ġaft er"]

    tokenizer = keras_nlp.models.BartTokenizer(
        vocabulary=vocab,
        merges=merges,
    )
    preprocessor = keras_nlp.models.BartPreprocessor(
        tokenizer=tokenizer,
        sequence_length=20,
    )
    ```
    """

    def __init__(
        self,
        tokenizer,
        sequence_length=1024,
        truncate="round_robin",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer

        # TODO: Allow users to pass separate `sequence_length`s for encoder
        # and decoder.
        self.packer = MultiSegmentPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            truncate=truncate,
            sequence_length=sequence_length,
        )

[Review comment, Member]: should we make an issue for this?
[Reply, Collaborator, Author]: Resolved it in this PR itself.

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.packer.sequence_length,
                "truncate": self.packer.truncate,
            }
        )
        return config

    def call(self, x, y=None, sample_weight=None):
        if not (
            isinstance(x, dict)
            and ["encoder_inputs", "decoder_inputs"] == list(x.keys())
        ):
            raise ValueError(
                '`x` must be a dictionary, containing the keys `"encoder_inputs"`'
                f' and `"decoder_inputs"`. Received x={x}.'
            )

        encoder_inputs = x["encoder_inputs"]
        decoder_inputs = x["decoder_inputs"]

        encoder_inputs = convert_inputs_to_list_of_tensor_segments(
            encoder_inputs
        )
        encoder_inputs = [self.tokenizer(segment) for segment in encoder_inputs]
        encoder_token_ids, _ = self.packer(encoder_inputs)

        decoder_inputs = convert_inputs_to_list_of_tensor_segments(
            decoder_inputs
        )
        decoder_inputs = [self.tokenizer(segment) for segment in decoder_inputs]
        decoder_token_ids, _ = self.packer(decoder_inputs)

        x = {
            "encoder_token_ids": encoder_token_ids,
            "encoder_padding_mask": encoder_token_ids
            != self.tokenizer.pad_token_id,
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_token_ids
            != self.tokenizer.pad_token_id,
        }

        return pack_x_y_sample_weight(x, y, sample_weight)

    @classproperty
    def tokenizer_cls(cls):
        return BartTokenizer

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
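
For quick orientation, here is a minimal sketch of what a single call to this layer should return, based purely on the class docstring above. The preset name comes from the docstring's own example; the shape comment assumes the default `sequence_length=1024` and unbatched string input, and is an expectation rather than captured output.

```python
import keras_nlp

preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")
outputs = preprocessor({
    "encoder_inputs": "The fox was sleeping.",
    "decoder_inputs": "The fox was awake.",
})

# Per the docstring, `outputs` is a dict with four keys that can be fed
# directly to a BART model. For unbatched string input, each value should
# be a tensor of shape (sequence_length,), i.e. (1024,) by default.
for name in (
    "encoder_token_ids",
    "encoder_padding_mask",
    "decoder_token_ids",
    "decoder_padding_mask",
):
    print(name, outputs[name].shape)
```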
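
The `truncate` argument is easiest to see on raw token IDs. Below is a minimal sketch that drives `keras_nlp.layers.MultiSegmentPacker` directly, with the same constructor arguments the preprocessor passes in `__init__` above; the integer IDs are made up for illustration, and the packed layouts in the comments are what the docstring's description implies rather than verified output.

```python
import tensorflow as tf
import keras_nlp

def pack(truncate):
    packer = keras_nlp.layers.MultiSegmentPacker(
        sequence_length=8,
        start_value=0,  # stands in for "<s>"
        end_value=2,    # stands in for "</s>"
        pad_value=1,    # stands in for "<pad>"
        truncate=truncate,
    )
    seg_a = tf.constant([11, 12, 13, 14])
    seg_b = tf.constant([21, 22, 23, 24])
    token_ids, _ = packer([seg_a, seg_b])
    return token_ids

# Eight slots minus three special tokens leaves five for content.
# "round_robin" assigns them one at a time across both segments,
# so seg_a should keep three tokens and seg_b two:
print(pack("round_robin"))  # e.g. [0, 11, 12, 13, 2, 21, 22, 2]
# "waterfall" fills left to right, so seg_a should keep all four
# tokens and seg_b only one:
print(pack("waterfall"))    # e.g. [0, 11, 12, 13, 14, 2, 21, 2]
```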
136 changes: 136 additions & 0 deletions keras_nlp/models/bart/bart_seq_2_seq_preprocessor.py
@@ -0,0 +1,136 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BART Seq2Seq preprocessor layer."""

from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.BartSeq2SeqPreprocessor")
class BartSeq2SeqPreprocessor(BartPreprocessor):
"""BART Seq2Seq preprocessor.

This layer is used as preprocessor for seq2seq tasks using the BART model.
This class subclasses `keras_nlp.models.BartPreprocessor` and keeps most of
its functionality. It has two changes from the superclass:

- Sets the `y` (label) and `sample_weights` fields by shifting the
decoder input sequence one step towards the left. Both these fields are
inferred internally, and any passed values will be ignored.
- Drops the last token from the decoder input sequence as it does not have
a successor.

Args:
tokenizer: A `keras_nlp.models.BartTokenizer` instance.
sequence_length: The length of the packed inputs.

[Review comment, Member]: Probably worth mentioning that this is the length for both encoder and decoder sequences (for now).

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.BartSeq2SeqPreprocessor.from_preset("bart_base_en")

    # Tokenize and pack a single sentence.
    inputs = {
        "encoder_inputs": "The fox was sleeping.",
        "decoder_inputs": "The fox was awake."
    }
    preprocessor(inputs)
    # Same output.
    inputs = {
        "encoder_inputs": tf.constant("The fox was sleeping."),
        "decoder_inputs": tf.constant("The fox was awake.")
    }
    preprocessor(inputs)

    # Tokenize a batch of single sentences.
    inputs = {
        "encoder_inputs": ["The fox was sleeping.", "The lion was quiet."],
        "decoder_inputs": ["The fox was awake.", "The lion was roaring."]
    }
    preprocessor(inputs)
    # Same output.
    inputs = {
        "encoder_inputs": tf.constant(
            ["The fox was sleeping.", "The lion was quiet."]
        ),
        "decoder_inputs": tf.constant(
            ["The fox was awake.", "The lion was roaring."]
        )
    }
    preprocessor(inputs)

    # Map a dataset to preprocess a single sentence.
    features = {
        "encoder_inputs": tf.constant(
            ["The fox was sleeping.", "The lion was quiet."]
        ),
        "decoder_inputs": tf.constant(
            ["The fox was awake.", "The lion was roaring."]
        )
    }
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Alternatively, you can create a preprocessor from your own vocabulary.
    # The usage is exactly the same as above.
    vocab = {
        "<s>": 0,
        "<pad>": 1,
        "</s>": 2,
        "Ġafter": 5,
        "noon": 6,
        "Ġsun": 7,
    }
    merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
    merges += ["Ġsu n", "Ġaf t", "Ġaft er"]

    tokenizer = keras_nlp.models.BartTokenizer(
        vocabulary=vocab,
        merges=merges,
    )
    preprocessor = keras_nlp.models.BartSeq2SeqPreprocessor(
        tokenizer=tokenizer,
        sequence_length=20,
    )
    ```
    """

    def call(self, x, y=None, sample_weight=None):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`BartSeq2SeqPreprocessor` infers `y` and `sample_weight` "
                "from the provided input data, i.e., `x`. However, "
                "non-`None` values have been passed for `y` or "
                "`sample_weight` or both. These values will be ignored."
            )

        x = super().call(x)
        decoder_token_ids = x.pop("decoder_token_ids")
        decoder_padding_mask = x.pop("decoder_padding_mask")

        # The last token does not have a next token. Hence, we truncate it.
        x = {
            **x,
            "decoder_token_ids": decoder_token_ids[..., :-1],
            "decoder_padding_mask": decoder_padding_mask[..., :-1],
        }

[Review comment, Member]: Will this actually work as we want? I think this will generate an encoder sequence with length `sequence_length` but a decoder sequence with length `sequence_length - 1`. We want both feature sequences to have the same length, I think, which means we have to tokenize the encoder sequence with length `sequence_length` and the decoder with length `sequence_length + 1` before the feature/label offsetting.
[Reply, abheesht17 (Collaborator, Author)]: @mattdangerw - in that case, we'll need to define two MultiSegmentPackers. Might as well work on #904 in this PR itself instead of saving it for later?
        # Target `y` will be the decoder input sequence shifted one step to
        # the left (i.e., the next token).
        y = decoder_token_ids[..., 1:]
        sample_weight = decoder_padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)
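
To make the sequence-length discussion in the review thread above concrete, here is a minimal sketch of the feature/label offsetting that this `call` performs, run on a hand-made packed decoder sequence; the token IDs are illustrative and `1` stands in for `pad_token_id`.

```python
import tensorflow as tf

# A packed decoder sequence, "<s> w1 w2 </s> <pad>", with pad id 1.
decoder_token_ids = tf.constant([[0, 11, 12, 2, 1]])
decoder_padding_mask = decoder_token_ids != 1

# Features drop the last position, which has no successor to predict.
features = decoder_token_ids[..., :-1]  # [[0, 11, 12, 2]]
# Labels are the same sequence shifted one step to the left.
labels = decoder_token_ids[..., 1:]  # [[11, 12, 2, 1]]
# Sample weights mask label positions that are padding.
sample_weight = decoder_padding_mask[..., 1:]  # [[True, True, True, False]]

# As the reviewer points out, the decoder features end up one token
# shorter than the packed length, while the encoder keeps the full
# `sequence_length`; packing the decoder to `sequence_length + 1` before
# offsetting would equalize the two.
```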