
Commit 3a951a9

Add FNet Preprocessor (#646)
* Add FNet Preprocessor
* Add imports to __init__.py
* Remove padding_mask ref
* Change fnet to f_net
1 parent f7dfc7b commit 3a951a9

6 files changed: +534 −0 lines changed


keras_nlp/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@
 from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
     DistilBertTokenizer,
 )
+from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
+from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
 from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
 from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
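
With these exports in place, both symbols are reachable from the top-level `keras_nlp.models` namespace. A minimal sketch of the intended usage (not part of this diff; the `"model.spm"` path is a placeholder for a real SentencePiece proto file):

```python
import keras_nlp

# Placeholder proto path; any trained SentencePiece model file would do.
tokenizer = keras_nlp.models.FNetTokenizer(proto="model.spm")
preprocessor = keras_nlp.models.FNetPreprocessor(tokenizer=tokenizer)
```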

keras_nlp/models/f_net/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
keras_nlp/models/f_net/f_net_preprocessor.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FNet preprocessor layer."""

from tensorflow import keras

from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class FNetPreprocessor(Preprocessor):
    """An FNet preprocessing layer which tokenizes and packs inputs.

    This preprocessing layer will do three things:

    - Tokenize any number of input segments using the `tokenizer`.
    - Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
      with the appropriate `"[CLS]"`, `"[SEP]"` and `"<pad>"` tokens.
    - Construct a dictionary with keys `"token_ids"` and `"segment_ids"` that
      can be passed directly to `keras_nlp.models.FNetBackbone`.

    This layer can be used directly with `tf.data.Dataset.map` to preprocess
    string data in the `(x, y, sample_weight)` format used by
    `keras.Model.fit`.

    The call method of this layer accepts three arguments, `x`, `y`, and
    `sample_weight`. `x` can be a python string or tensor representing a single
    segment, a list of python strings representing a batch of single segments,
    or a list of tensors representing multiple segments to be packed together.
    `y` and `sample_weight` are both optional, can have any format, and will be
    passed through unaltered.

    Special care should be taken when using `tf.data` to map over an unlabeled
    tuple of string segments. `tf.data.Dataset.map` will unpack this tuple
    directly into the call arguments of this layer, rather than forwarding all
    arguments to `x`. To handle this case, it is recommended to explicitly call
    the layer, e.g. `ds.map(lambda seg1, seg2: preprocessor(x=(seg1, seg2)))`.

    Args:
        tokenizer: A `keras_nlp.models.FNetTokenizer` instance.
        sequence_length: The length of the packed inputs.
        truncate: string. The algorithm to truncate a list of batched segments
            to fit within `sequence_length`. The value can be either
            `"round_robin"` or `"waterfall"`:
            - `"round_robin"`: Available space is assigned one token at a
                time in a round-robin fashion to the inputs that still need
                some, until the limit is reached.
            - `"waterfall"`: The allocation of the budget is done using a
                "waterfall" algorithm that allocates quota in a
                left-to-right manner and fills up the buckets until we run
                out of budget. It supports an arbitrary number of segments.

    Examples:
    ```python
    tokenizer = keras_nlp.models.FNetTokenizer(proto="model.spm")
    preprocessor = keras_nlp.models.FNetPreprocessor(
        tokenizer=tokenizer,
        sequence_length=10,
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("The quick brown fox jumped.")
    preprocessor(sentence)
    # Same output.
    preprocessor("The quick brown fox jumped.")

    # Tokenize and pack a batch of single sentences.
    sentences = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    preprocessor(sentences)
    # Same output.
    preprocessor(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )

    # Tokenize and pack a sentence pair.
    first_sentence = tf.constant("The quick brown fox jumped.")
    second_sentence = tf.constant("The fox tripped.")
    preprocessor((first_sentence, second_sentence))

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    labels = tf.constant([0, 1])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess sentence pairs.
    first_sentences = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    second_sentences = tf.constant(
        ["The fox tripped.", "Oh look, a whale."]
    )
    labels = tf.constant([1, 1])
    ds = tf.data.Dataset.from_tensor_slices(
        (
            (first_sentences, second_sentences), labels
        )
    )
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentence pairs.
    first_sentences = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    second_sentences = tf.constant(
        ["The fox tripped.", "Oh look, a whale."]
    )
    ds = tf.data.Dataset.from_tensor_slices((first_sentences, second_sentences))
    # Watch out for tf.data's default unpacking of tuples here!
    # Best to invoke the `preprocessor` directly in this case.
    ds = ds.map(
        lambda s1, s2: preprocessor(x=(s1, s2)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ```
    """

    def __init__(
        self,
        tokenizer,
        sequence_length=512,
        truncate="round_robin",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._tokenizer = tokenizer
        self.packer = MultiSegmentPacker(
            start_value=self.tokenizer.cls_token_id,
            end_value=self.tokenizer.sep_token_id,
            pad_value=self.tokenizer.pad_token_id,
            truncate=truncate,
            sequence_length=sequence_length,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.packer.sequence_length,
                "truncate": self.packer.truncate,
            }
        )
        return config

    def call(self, x, y=None, sample_weight=None):
        x = convert_inputs_to_list_of_tensor_segments(x)
        x = [self.tokenizer(segment) for segment in x]
        token_ids, segment_ids = self.packer(x)
        x = {
            "token_ids": token_ids,
            "segment_ids": segment_ids,
        }
        return pack_x_y_sample_weight(x, y, sample_weight)

    @classproperty
    def tokenizer_cls(cls):
        return FNetTokenizer
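
As the docstring above notes, the output dictionary of `"token_ids"` and `"segment_ids"` is meant to feed `keras_nlp.models.FNetBackbone` directly. A rough end-to-end sketch (not part of this commit; the `FNetBackbone` constructor arguments and the `"model.spm"` path are assumptions for illustration):

```python
import keras_nlp
import tensorflow as tf

tokenizer = keras_nlp.models.FNetTokenizer(proto="model.spm")  # placeholder proto path
preprocessor = keras_nlp.models.FNetPreprocessor(
    tokenizer=tokenizer,
    sequence_length=128,
)
backbone = keras_nlp.models.FNetBackbone(  # argument names assumed, not from this diff
    vocabulary_size=tokenizer.vocabulary_size(),
    num_layers=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)
# The dict of "token_ids" and "segment_ids" produced by the preprocessor
# matches the inputs the backbone expects.
features = preprocessor(tf.constant(["The quick brown fox jumped."]))
hidden_states = backbone(features)
```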
keras_nlp/models/f_net/f_net_preprocessor_test.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for FNet preprocessor layer."""

import io
import os

import sentencepiece
import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer


class FNetPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        bytes_io = io.BytesIO()
        vocab_data = tf.data.Dataset.from_tensor_slices(
            ["the quick brown fox", "the earth is round"]
        )
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=vocab_data.as_numpy_iterator(),
            model_writer=bytes_io,
            vocab_size=10,
            model_type="WORD",
            pad_id=3,
            unk_id=0,
            bos_id=4,
            eos_id=5,
            pad_piece="<pad>",
            unk_piece="<unk>",
            bos_piece="[CLS]",
            eos_piece="[SEP]",
        )
        self.proto = bytes_io.getvalue()

        self.preprocessor = FNetPreprocessor(
            tokenizer=FNetTokenizer(proto=self.proto),
            sequence_length=12,
        )

    def test_tokenize_strings(self):
        input_data = "the quick brown fox"
        output = self.preprocessor(input_data)
        self.assertAllEqual(
            output["token_ids"], [4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]
        )
        self.assertAllEqual(
            output["segment_ids"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        )

    def test_tokenize_list_of_strings(self):
        # We should handle a list of strings as a batch.
        input_data = ["the quick brown fox"] * 4
        output = self.preprocessor(input_data)
        self.assertAllEqual(
            output["token_ids"],
            [[4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]] * 4,
        )
        self.assertAllEqual(
            output["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )

    def test_tokenize_labeled_batch(self):
        x = tf.constant(["the quick brown fox"] * 4)
        y = tf.constant([1] * 4)
        sw = tf.constant([1.0] * 4)
        x_out, y_out, sw_out = self.preprocessor(x, y, sw)
        self.assertAllEqual(
            x_out["token_ids"],
            [[4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]] * 4,
        )
        self.assertAllEqual(
            x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(y_out, y)
        self.assertAllEqual(sw_out, sw)

    def test_tokenize_labeled_dataset(self):
        x = tf.constant(["the quick brown fox"] * 4)
        y = tf.constant([1] * 4)
        sw = tf.constant([1.0] * 4)
        ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
        ds = ds.map(self.preprocessor)
        x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element()
        self.assertAllEqual(
            x_out["token_ids"],
            [[4, 1, 9, 2, 7, 5, 3, 3, 3, 3, 3, 3]] * 4,
        )
        self.assertAllEqual(
            x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(y_out, y)
        self.assertAllEqual(sw_out, sw)

    def test_tokenize_multiple_sentences(self):
        sentence_one = tf.constant("the quick brown fox")
        sentence_two = tf.constant("the earth")
        output = self.preprocessor((sentence_one, sentence_two))
        self.assertAllEqual(
            output["token_ids"],
            [4, 1, 9, 2, 7, 5, 1, 6, 5, 3, 3, 3],
        )
        self.assertAllEqual(
            output["segment_ids"], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]
        )

    def test_tokenize_multiple_batched_sentences(self):
        sentence_one = tf.constant(["the quick brown fox"] * 4)
        sentence_two = tf.constant(["the earth"] * 4)
        # The first tuple or list is always interpreted as an enumeration of
        # separate sequences to concatenate.
        output = self.preprocessor((sentence_one, sentence_two))
        self.assertAllEqual(
            output["token_ids"],
            [[4, 1, 9, 2, 7, 5, 1, 6, 5, 3, 3, 3]] * 4,
        )
        self.assertAllEqual(
            output["segment_ids"], [[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]] * 4
        )

    def test_errors_for_2d_list_input(self):
        ambiguous_input = [["one", "two"], ["three", "four"]]
        with self.assertRaises(ValueError):
            self.preprocessor(ambiguous_input)

    @parameterized.named_parameters(
        ("tf_format", "tf", "model"),
        ("keras_format", "keras_v3", "model.keras"),
    )
    def test_saved_model(self, save_format, filename):
        input_data = tf.constant(["the quick brown fox"])
        inputs = keras.Input(dtype="string", shape=())
        outputs = self.preprocessor(inputs)
        model = keras.Model(inputs, outputs)
        path = os.path.join(self.get_temp_dir(), filename)
        model.save(path, save_format=save_format)
        restored_model = keras.models.load_model(path)
        self.assertAllEqual(
            model(input_data)["token_ids"],
            restored_model(input_data)["token_ids"],
        )
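
For readers tracing the expected arrays above: with the toy SentencePiece model trained in `setUp` (pad_id=3, bos piece `"[CLS]"`=4, eos piece `"[SEP]"`=5), every packed sequence starts with 4, closes each segment with 5, and pads with 3 up to `sequence_length`. A small sketch that surfaces the ids the packer is configured with (`proto` here stands for the trained model bytes from `setUp`, not a real file):

```python
# Sketch only: `proto` is assumed to hold the toy SentencePiece model bytes
# trained above (vocab_size=10, pad_id=3, bos/[CLS]=4, eos/[SEP]=5).
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer

tokenizer = FNetTokenizer(proto=proto)
print(tokenizer.cls_token_id)  # 4 -> first id of every packed sequence
print(tokenizer.sep_token_id)  # 5 -> closes each segment
print(tokenizer.pad_token_id)  # 3 -> fills the sequence up to sequence_length
```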
