2 changes: 2 additions & 0 deletions keras_nlp/models/__init__.py
@@ -28,6 +28,8 @@
from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
DistilBertTokenizer,
)
from keras_nlp.models.fnet.fnet_preprocessor import FNetPreprocessor
from keras_nlp.models.fnet.fnet_tokenizer import FNetTokenizer
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
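With these two lines, both symbols become importable from the top-level `keras_nlp.models` namespace. A minimal usage sketch, assuming a trained SentencePiece model file is available (the `"fnet_vocab.spm"` path below is a placeholder, not a file shipped with this PR):

```python
import keras_nlp

# Placeholder path; any SentencePiece model trained for FNet works here.
tokenizer = keras_nlp.models.FNetTokenizer(proto="fnet_vocab.spm")
preprocessor = keras_nlp.models.FNetPreprocessor(
    tokenizer=tokenizer,
    sequence_length=128,
)
print(preprocessor("The quick brown fox jumped."))
```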
13 changes: 13 additions & 0 deletions keras_nlp/models/fnet/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
180 changes: 180 additions & 0 deletions keras_nlp/models/fnet/fnet_preprocessor.py
@@ -0,0 +1,180 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FNet preprocessor layer."""

from tensorflow import keras

from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.fnet.fnet_tokenizer import FNetTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class FNetPreprocessor(Preprocessor):
"""An FNet preprocessing layer which tokenizes and packs inputs.

This preprocessing layer will do three things:

- Tokenize any number of input segments using the `tokenizer`.
- Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
with the appropriate `"[CLS]"`, `"[SEP]"` and `"<pad>"` tokens.
- Construct a dictionary with keys `"token_ids"` and `"segment_ids"`
that can be passed directly to
`keras_nlp.models.FNetBackbone`.

This layer can be used directly with `tf.data.Dataset.map` to preprocess
string data in the `(x, y, sample_weight)` format used by
`keras.Model.fit`.

The call method of this layer accepts three arguments, `x`, `y`, and
`sample_weight`. `x` can be a python string or tensor representing a single
segment, a list of python strings representing a batch of single segments,
or a list of tensors representing multiple segments to be packed together.
`y` and `sample_weight` are both optional, can have any format, and will be
passed through unaltered.

Special care should be taken when using `tf.data` to map over an unlabeled
tuple of string segments. `tf.data.Dataset.map` will unpack this tuple
directly into the call arguments of this layer, rather than forwarding the
whole tuple to `x`. To handle this case, it is recommended to explicitly call
the layer, e.g. `ds.map(lambda seg1, seg2: preprocessor(x=(seg1, seg2)))`.

Args:
tokenizer: A `keras_nlp.models.FNetTokenizer` instance.
sequence_length: The length of the packed inputs.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Examples:
```python
tokenizer = keras_nlp.models.FNetTokenizer(proto="model.spm")
preprocessor = keras_nlp.models.FNetPreprocessor(
tokenizer=tokenizer,
sequence_length=10,
)

# Tokenize and pack a single sentence.
sentence = tf.constant("The quick brown fox jumped.")
preprocessor(sentence)
# Same output.
preprocessor("The quick brown fox jumped.")

# Tokenize and pack a batch of single sentences.
sentences = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
preprocessor(sentences)
# Same output.
preprocessor(
["The quick brown fox jumped.", "Call me Ishmael."]
)

# Tokenize and pack a sentence pair.
first_sentence = tf.constant("The quick brown fox jumped.")
second_sentence = tf.constant("The fox tripped.")
preprocessor((first_sentence, second_sentence))

# Map a dataset to preprocess a single sentence.
features = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
labels = tf.constant([0, 1])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map a dataset to preprocess sentence pairs.
first_sentences = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
second_sentences = tf.constant(
["The fox tripped.", "Oh look, a whale."]
)
labels = tf.constant([1, 1])
ds = tf.data.Dataset.from_tensor_slices(
(
(first_sentences, second_sentences), labels
)
)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map a dataset to preprocess unlabeled sentence pairs.
first_sentences = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
second_sentences = tf.constant(
["The fox tripped.", "Oh look, a whale."]
)
ds = tf.data.Dataset.from_tensor_slices((first_sentences, second_sentences))
# Watch out for tf.data's default unpacking of tuples here!
# Best to invoke the `preprocessor` directly in this case.
ds = ds.map(
lambda s1, s2: preprocessor(x=(s1, s2)),
num_parallel_calls=tf.data.AUTOTUNE,
)
```
"""

def __init__(
self,
tokenizer,
sequence_length=512,
truncate="round_robin",
**kwargs,
):
super().__init__(**kwargs)
self._tokenizer = tokenizer
self.packer = MultiSegmentPacker(
start_value=self.tokenizer.cls_token_id,
end_value=self.tokenizer.sep_token_id,
pad_value=self.tokenizer.pad_token_id,
truncate=truncate,
sequence_length=sequence_length,
)

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.packer.sequence_length,
"truncate": self.packer.truncate,
}
)
return config

def call(self, x, y=None, sample_weight=None):
x = convert_inputs_to_list_of_tensor_segments(x)
x = [self.tokenizer(segment) for segment in x]
token_ids, segment_ids = self.packer(x)
x = {
"token_ids": token_ids,
"segment_ids": segment_ids,
}
return pack_x_y_sample_weight(x, y, sample_weight)

@classproperty
def tokenizer_cls(cls):
return FNetTokenizer
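For intuition, here is a minimal sketch of what `call` returns for a sentence pair, assuming a `preprocessor` constructed as in the docstring example above (concrete id values depend on the SentencePiece vocabulary and are illustrative only):

```python
import tensorflow as tf

# Assumes `preprocessor` was built as in the docstring example above.
first = tf.constant("The quick brown fox jumped.")
second = tf.constant("The fox tripped.")
features = preprocessor((first, second))

# The output is a dict with exactly two keys; FNet takes no "padding_mask".
# token_ids:   [CLS] first-segment [SEP] second-segment [SEP] <pad> ...
# segment_ids: 0 for the first segment, 1 for the second, 0 on padding.
print(features["token_ids"])
print(features["segment_ids"])
```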
156 changes: 156 additions & 0 deletions keras_nlp/models/fnet/fnet_preprocessor_test.py
@@ -0,0 +1,156 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for FNet preprocessor layer."""

import io
import os

import sentencepiece
import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models.fnet.fnet_preprocessor import FNetPreprocessor
from keras_nlp.models.fnet.fnet_tokenizer import FNetTokenizer


class FNetPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
bytes_io = io.BytesIO()
vocab_data = tf.data.Dataset.from_tensor_slices(
["the quick brown fox", "the earth is round"]
)
sentencepiece.SentencePieceTrainer.train(
sentence_iterator=vocab_data.as_numpy_iterator(),
model_writer=bytes_io,
vocab_size=10,
model_type="WORD",
pad_id=3,
unk_id=0,
bos_id=4,
eos_id=5,
pad_piece="<pad>",
unk_piece="<unk>",
bos_piece="[CLS]",
eos_piece="[SEP]",
)
self.proto = bytes_io.getvalue()

self.preprocessor = FNetPreprocessor(
tokenizer=FNetTokenizer(proto=self.proto),
sequence_length=12,
)

def test_tokenize_strings(self):
input_data = "the quick brown fox"
output = self.preprocessor(input_data)
self.assertAllEqual(
output["token_ids"], [4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]
)
self.assertAllEqual(
output["segment_ids"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)

def test_tokenize_list_of_strings(self):
# We should handle a list of strings as a batch.
input_data = ["the quick brown fox"] * 4
output = self.preprocessor(input_data)
self.assertAllEqual(
output["token_ids"],
[[4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]] * 4,
)
self.assertAllEqual(
output["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
)

def test_tokenize_labeled_batch(self):
x = tf.constant(["the quick brown fox"] * 4)
y = tf.constant([1] * 4)
sw = tf.constant([1.0] * 4)
x_out, y_out, sw_out = self.preprocessor(x, y, sw)
self.assertAllEqual(
x_out["token_ids"],
[[4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]] * 4,
)
self.assertAllEqual(
x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
)
self.assertAllEqual(y_out, y)
self.assertAllEqual(sw_out, sw)

def test_tokenize_labeled_dataset(self):
x = tf.constant(["the quick brown fox"] * 4)
y = tf.constant([1] * 4)
sw = tf.constant([1.0] * 4)
ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
ds = ds.map(self.preprocessor)
x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element()
self.assertAllEqual(
x_out["token_ids"],
[[4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]] * 4,
)
self.assertAllEqual(
x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
)
self.assertAllEqual(y_out, y)
self.assertAllEqual(sw_out, sw)

def test_tokenize_multiple_sentences(self):
sentence_one = tf.constant("the quick brown fox")
sentence_two = tf.constant("the earth")
output = self.preprocessor((sentence_one, sentence_two))
self.assertAllEqual(
output["token_ids"],
[4, 1, 9, 2, 7, 5, 1, 6, 5, 3, 3, 3],
)
self.assertAllEqual(
output["segment_ids"], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]
)

def test_tokenize_multiple_batched_sentences(self):
sentence_one = tf.constant(["the quick brown fox"] * 4)
sentence_two = tf.constant(["the earth"] * 4)
# The first tuple or list is always interpreted as an enumeration of
# separate sequences to concatenate.
output = self.preprocessor((sentence_one, sentence_two))
self.assertAllEqual(
output["token_ids"],
[[4, 1, 9, 2, 7, 5, 1, 6, 5, 3, 3, 3]] * 4,
)
self.assertAllEqual(
output["segment_ids"], [[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]] * 4
)

def test_errors_for_2d_list_input(self):
ambiguous_input = [["one", "two"], ["three", "four"]]
with self.assertRaises(ValueError):
self.preprocessor(ambiguous_input)

@parameterized.named_parameters(
("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),
)
def test_saved_model(self, save_format, filename):
input_data = tf.constant(["the quick brown fox"])
inputs = keras.Input(dtype="string", shape=())
outputs = self.preprocessor(inputs)
model = keras.Model(inputs, outputs)
path = os.path.join(self.get_temp_dir(), filename)
model.save(path, save_format=save_format)
restored_model = keras.models.load_model(path)
self.assertAllEqual(
model(input_data)["token_ids"],
restored_model(input_data)["token_ids"],
)
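The literal ids in the assertions above come straight from the toy SentencePiece model trained in `setUp`. A small standalone sketch of that piece-to-id mapping, assuming a reasonably recent `sentencepiece` release that can load a model from an in-memory proto:

```python
import io

import sentencepiece
import tensorflow as tf

# Rebuild the same toy vocabulary as setUp(); pad_id=3, bos_id=4 ("[CLS]") and
# eos_id=5 ("[SEP]") explain the expected ids in the tests above.
bytes_io = io.BytesIO()
vocab_data = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=vocab_data.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=10,
    model_type="WORD",
    pad_id=3,
    unk_id=0,
    bos_id=4,
    eos_id=5,
    pad_piece="<pad>",
    unk_piece="<unk>",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
)

# Print every (id, piece) pair in the 10-token vocabulary.
sp = sentencepiece.SentencePieceProcessor(model_proto=bytes_io.getvalue())
for token_id in range(sp.get_piece_size()):
    print(token_id, sp.id_to_piece(token_id))
```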