Commit 802c7ef

Add AlbertTokenizer and AlbertPreprocessor (#627)
* Add AlbertTokenizer and AlbertPreprocessor
* Fix UTs
* Fix UTs (1)
* Update albert_preprocessor.py

Co-authored-by: Matt Watson <[email protected]>
1 parent f54d24c commit 802c7ef

6 files changed (+574 lines, -0 lines)


keras_nlp/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
+ from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
  from keras_nlp.models.bert.bert_backbone import BertBackbone
  from keras_nlp.models.bert.bert_classifier import BertClassifier
  from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
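
With this change, both new classes are exported from the top-level `keras_nlp.models` namespace. A minimal usage sketch (assuming a trained SentencePiece model file such as `model.spm`, mirroring the docstring example further down):

```python
import keras_nlp

# The new ALBERT classes sit alongside the existing BERT exports.
tokenizer = keras_nlp.models.AlbertTokenizer(proto="model.spm")
preprocessor = keras_nlp.models.AlbertPreprocessor(
    tokenizer=tokenizer,
    sequence_length=128,
)

# Returns a dict with "token_ids", "segment_ids" and "padding_mask".
features = preprocessor("The quick brown fox jumped.")
```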

keras_nlp/models/albert/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

keras_nlp/models/albert/albert_preprocessor.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT preprocessor layer."""

from tensorflow import keras

from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertPreprocessor(keras.layers.Layer):
    """An ALBERT preprocessing layer which tokenizes and packs inputs.

    This preprocessing layer will do three things:

    - Tokenize any number of input segments using the `tokenizer`.
    - Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
      with the appropriate `"[CLS]"`, `"[SEP]"` and `"<pad>"` tokens.
    - Construct a dictionary with keys `"token_ids"`, `"segment_ids"` and
      `"padding_mask"`, that can be passed directly to
      `keras_nlp.models.AlbertBackbone`.

    This layer can be used directly with `tf.data.Dataset.map` to preprocess
    string data in the `(x, y, sample_weight)` format used by
    `keras.Model.fit`.

    The call method of this layer accepts three arguments, `x`, `y`, and
    `sample_weight`. `x` can be a python string or tensor representing a single
    segment, a list of python strings representing a batch of single segments,
    or a list of tensors representing multiple segments to be packed together.
    `y` and `sample_weight` are both optional, can have any format, and will be
    passed through unaltered.

    Special care should be taken when using `tf.data` to map over an unlabeled
    tuple of string segments. `tf.data.Dataset.map` will unpack this tuple
    directly into the call arguments of this layer, rather than forwarding all
    arguments to `x`. To handle this case, it is recommended to explicitly call
    the layer, e.g. `ds.map(lambda seg1, seg2: preprocessor(x=(seg1, seg2)))`.

    Args:
        tokenizer: A `keras_nlp.models.AlbertTokenizer` instance.
        sequence_length: The length of the packed inputs.
        truncate: string. The algorithm to truncate a list of batched segments
            to fit within `sequence_length`. The value can be either
            `round_robin` or `waterfall`:
            - `"round_robin"`: Available space is assigned one token at a
              time in a round-robin fashion to the inputs that still need
              some, until the limit is reached.
            - `"waterfall"`: The allocation of the budget is done using a
              "waterfall" algorithm that allocates quota in a
              left-to-right manner and fills up the buckets until we run
              out of budget. It supports an arbitrary number of segments.

    Examples:
    ```python
    tokenizer = keras_nlp.models.AlbertTokenizer(proto="model.spm")
    preprocessor = keras_nlp.models.AlbertPreprocessor(
        tokenizer=tokenizer,
        sequence_length=10,
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("The quick brown fox jumped.")
    preprocessor(sentence)
    # Same output.
    preprocessor("The quick brown fox jumped.")

    # Tokenize and pack a batch of single sentences.
    sentences = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    preprocessor(sentences)
    # Same output.
    preprocessor(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )

    # Tokenize and pack a sentence pair.
    first_sentence = tf.constant("The quick brown fox jumped.")
    second_sentence = tf.constant("The fox tripped.")
    preprocessor((first_sentence, second_sentence))

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    labels = tf.constant([0, 1])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess sentence pairs.
    first_sentences = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    second_sentences = tf.constant(
        ["The fox tripped.", "Oh look, a whale."]
    )
    labels = tf.constant([1, 1])
    ds = tf.data.Dataset.from_tensor_slices(
        (
            (first_sentences, second_sentences), labels
        )
    )
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentence pairs.
    first_sentences = tf.constant(
        ["The quick brown fox jumped.", "Call me Ishmael."]
    )
    second_sentences = tf.constant(
        ["The fox tripped.", "Oh look, a whale."]
    )
    ds = tf.data.Dataset.from_tensor_slices((first_sentences, second_sentences))
    # Watch out for tf.data's default unpacking of tuples here!
    # Best to invoke the `preprocessor` directly in this case.
    ds = ds.map(
        lambda s1, s2: preprocessor(x=(s1, s2)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ```
    """

    def __init__(
        self,
        tokenizer,
        sequence_length=512,
        truncate="round_robin",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._tokenizer = tokenizer
        self.packer = MultiSegmentPacker(
            start_value=self.tokenizer.cls_token_id,
            end_value=self.tokenizer.sep_token_id,
            pad_value=self.tokenizer.pad_token_id,
            truncate=truncate,
            sequence_length=sequence_length,
        )

    @property
    def tokenizer(self):
        """The `keras_nlp.models.AlbertTokenizer` used to tokenize strings."""
        return self._tokenizer

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "tokenizer": keras.layers.serialize(self.tokenizer),
                "sequence_length": self.packer.sequence_length,
                "truncate": self.packer.truncate,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        if "tokenizer" in config and isinstance(config["tokenizer"], dict):
            config["tokenizer"] = keras.layers.deserialize(config["tokenizer"])
        return cls(**config)

    def call(self, x, y=None, sample_weight=None):
        x = convert_inputs_to_list_of_tensor_segments(x)
        x = [self.tokenizer(segment) for segment in x]
        token_ids, segment_ids = self.packer(x)
        x = {
            "token_ids": token_ids,
            "segment_ids": segment_ids,
            "padding_mask": token_ids != self.tokenizer.pad_token_id,
        }
        return pack_x_y_sample_weight(x, y, sample_weight)

    @classproperty
    def presets(cls):
        return {}

    @classmethod
    def from_preset(
        cls,
        preset,
        sequence_length=None,
        truncate="round_robin",
        **kwargs,
    ):
        raise NotImplementedError
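
Because `get_config` stores the wrapped tokenizer via `keras.layers.serialize` and `from_config` reverses it, the layer should round-trip through its config. A minimal sketch of that round trip (assuming a trained SentencePiece file `model.spm`, as in the docstring example; variable names are illustrative):

```python
import keras_nlp

tokenizer = keras_nlp.models.AlbertTokenizer(proto="model.spm")
preprocessor = keras_nlp.models.AlbertPreprocessor(
    tokenizer=tokenizer,
    sequence_length=10,
)

config = preprocessor.get_config()
# The tokenizer is nested in the config as a serialized layer dict.
restored = keras_nlp.models.AlbertPreprocessor.from_config(config)

# The restored layer should keep the same packing settings.
assert restored.packer.sequence_length == preprocessor.packer.sequence_length
```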

keras_nlp/models/albert/albert_preprocessor_test.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ALBERT preprocessor layer."""

import io
import os

import sentencepiece
import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer


class AlbertPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        bytes_io = io.BytesIO()
        vocab_data = tf.data.Dataset.from_tensor_slices(
            ["the quick brown fox", "the earth is round"]
        )
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=vocab_data.as_numpy_iterator(),
            model_writer=bytes_io,
            vocab_size=10,
            model_type="WORD",
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            pad_piece="<pad>",
            unk_piece="<unk>",
            bos_piece="[CLS]",
            eos_piece="[SEP]",
        )
        self.proto = bytes_io.getvalue()

        self.preprocessor = AlbertPreprocessor(
            tokenizer=AlbertTokenizer(proto=self.proto),
            sequence_length=12,
        )

    def test_tokenize_strings(self):
        input_data = "the quick brown fox"
        output = self.preprocessor(input_data)
        self.assertAllEqual(
            output["token_ids"], [2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]
        )
        self.assertAllEqual(
            output["segment_ids"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        )
        self.assertAllEqual(
            output["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
        )

    def test_tokenize_list_of_strings(self):
        # We should handle a list of strings as a batch.
        input_data = ["the quick brown fox"] * 4
        output = self.preprocessor(input_data)
        self.assertAllEqual(
            output["token_ids"],
            [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4,
        )
        self.assertAllEqual(
            output["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(
            output["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4
        )

    def test_tokenize_labeled_batch(self):
        x = tf.constant(["the quick brown fox"] * 4)
        y = tf.constant([1] * 4)
        sw = tf.constant([1.0] * 4)
        x_out, y_out, sw_out = self.preprocessor(x, y, sw)
        self.assertAllEqual(
            x_out["token_ids"],
            [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4,
        )
        self.assertAllEqual(
            x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(
            x_out["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(y_out, y)
        self.assertAllEqual(sw_out, sw)

    def test_tokenize_labeled_dataset(self):
        x = tf.constant(["the quick brown fox"] * 4)
        y = tf.constant([1] * 4)
        sw = tf.constant([1.0] * 4)
        ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
        ds = ds.map(self.preprocessor)
        x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element()
        self.assertAllEqual(
            x_out["token_ids"],
            [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4,
        )
        self.assertAllEqual(
            x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(
            x_out["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4
        )
        self.assertAllEqual(y_out, y)
        self.assertAllEqual(sw_out, sw)

    def test_tokenize_multiple_sentences(self):
        sentence_one = tf.constant("the quick brown fox")
        sentence_two = tf.constant("the earth")
        output = self.preprocessor((sentence_one, sentence_two))
        self.assertAllEqual(
            output["token_ids"],
            [2, 4, 9, 5, 7, 3, 4, 6, 3, 0, 0, 0],
        )
        self.assertAllEqual(
            output["segment_ids"], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]
        )
        self.assertAllEqual(
            output["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
        )

    def test_tokenize_multiple_batched_sentences(self):
        sentence_one = tf.constant(["the quick brown fox"] * 4)
        sentence_two = tf.constant(["the earth"] * 4)
        # The first tuple or list is always interpreted as an enumeration of
        # separate sequences to concatenate.
        output = self.preprocessor((sentence_one, sentence_two))
        self.assertAllEqual(
            output["token_ids"],
            [[2, 4, 9, 5, 7, 3, 4, 6, 3, 0, 0, 0]] * 4,
        )
        self.assertAllEqual(
            output["segment_ids"], [[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]] * 4
        )
        self.assertAllEqual(
            output["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
        )

    def test_errors_for_2d_list_input(self):
        ambiguous_input = [["one", "two"], ["three", "four"]]
        with self.assertRaises(ValueError):
            self.preprocessor(ambiguous_input)

    @parameterized.named_parameters(
        ("tf_format", "tf", "model"),
        ("keras_format", "keras_v3", "model.keras"),
    )
    def test_saved_model(self, save_format, filename):
        input_data = tf.constant(["the quick brown fox"])
        inputs = keras.Input(dtype="string", shape=())
        outputs = self.preprocessor(inputs)
        model = keras.Model(inputs, outputs)
        path = os.path.join(self.get_temp_dir(), filename)
        model.save(path, save_format=save_format)
        restored_model = keras.models.load_model(path)
        self.assertAllEqual(
            model(input_data)["token_ids"],
            restored_model(input_data)["token_ids"],
        )
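
The fixture in `setUp` trains a throwaway word-level SentencePiece model in memory rather than shipping a vocabulary file. A standalone sketch of the same idea, useful for local experimentation (assumes `sentencepiece` is installed; the values mirror the test above):

```python
import io

import sentencepiece
import tensorflow as tf

from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer

# Train a tiny word-level vocabulary with ALBERT's special tokens.
bytes_io = io.BytesIO()
vocab_data = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=vocab_data.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=10,
    model_type="WORD",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    pad_piece="<pad>",
    unk_piece="<unk>",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
)

preprocessor = AlbertPreprocessor(
    tokenizer=AlbertTokenizer(proto=bytes_io.getvalue()),
    sequence_length=12,
)
# Expect token_ids starting [2, 4, 9, 5, 7, 3, ...], as asserted in the tests.
print(preprocessor("the quick brown fox"))
```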
