Skip to content

Commit 85514c5

Browse files
Add AlbertClassifier (#668)
* init commit * added classifier, from_preset_method unimplemented yet * running formatting, removing unused imports * incorporating suggested changes * formatting * updating docstrings * fixing errors due to merge * fixing formattinf * Fix test names --------- Co-authored-by: Matt Watson <[email protected]>
1 parent 2ae2d4d commit 85514c5

File tree

3 files changed

+410
-2
lines changed

3 files changed

+410
-2
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""ALBERT classification model."""
15+
16+
import copy
17+
18+
from tensorflow import keras
19+
20+
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
21+
from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer
22+
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
23+
from keras_nlp.models.albert.albert_presets import backbone_presets
24+
from keras_nlp.models.task import Task
25+
from keras_nlp.utils.python_utils import classproperty
26+
27+
28+
@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertClassifier(Task):
    """An end-to-end ALBERT model for classification tasks.

    This model attaches a classification head to a
    `keras_nlp.models.AlbertBackbone` backbone, mapping from the backbone
    outputs to logit output suitable for a classification task. For usage of
    this model with pre-trained weights, see the `from_preset()` method.

    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to raw inputs during
    `fit()`, `predict()`, and `evaluate()`. This is done by default when
    creating the model with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        backbone: A `keras_nlp.models.AlbertBackbone` instance.
        num_classes: int. Number of classes to predict. Defaults to `2`.
        dropout: float. The dropout probability value, applied to the
            backbone's pooled output before the final dense (logits) layer.
            Defaults to `0.1`.
        preprocessor: A `keras_nlp.models.AlbertPreprocessor` or `None`. If
            `None`, this model will not apply preprocessing, and inputs should
            be preprocessed before calling the model.

    Examples:

    Example usage.
    ```python
    # Define the preprocessed inputs.
    preprocessed_features = {
        "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
        "segment_ids": tf.constant(
            [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
        ),
    }
    labels = [0, 3]

    # Randomly initialize an ALBERT backbone.
    backbone = keras_nlp.models.AlbertBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_heads=2,
        embedding_dim=8,
        hidden_dim=64,
        intermediate_dim=128,
        max_sequence_length=128,
        name="encoder",
    )

    # Create an ALBERT classifier and fit your data.
    classifier = keras_nlp.models.AlbertClassifier(
        backbone,
        num_classes=4,
        preprocessor=None,
    )
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)

    # Access backbone programmatically (e.g., to change `trainable`)
    classifier.backbone.trainable = False
    ```

    Raw string inputs with customized preprocessing.
    ```python
    # Create a dataset with raw string features in an `(x, y)` format.
    features = ["The quick brown fox jumped.", "I forgot my homework."]
    labels = [0, 3]

    # Use a shorter sequence length.
    preprocessor = keras_nlp.models.AlbertPreprocessor.from_preset(
        "albert_base_en_uncased",
        sequence_length=128,
    )

    # Create an AlbertClassifier and fit your data.
    classifier = keras_nlp.models.AlbertClassifier.from_preset(
        "albert_base_en_uncased",
        num_classes=4,
        preprocessor=preprocessor,
    )
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    classifier.fit(x=features, y=labels, batch_size=2)
    ```

    Preprocessed inputs.
    ```python
    # Create a dataset with preprocessed features in an `(x, y)` format.
    preprocessed_features = {
        "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
        "segment_ids": tf.constant(
            [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
        ),
    }
    labels = [0, 3]

    # Create an ALBERT classifier and fit your data.
    classifier = keras_nlp.models.AlbertClassifier.from_preset(
        "albert_base_en_uncased",
        num_classes=4,
        preprocessor=None,
    )
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
    ```
    """

    def __init__(
        self,
        backbone,
        num_classes=2,
        dropout=0.1,
        preprocessor=None,
        **kwargs,
    ):
        # Reuse the backbone's own symbolic inputs so the classifier accepts
        # exactly the same input structure as the backbone.
        inputs = backbone.input
        # Classification head: pooled output -> dropout -> dense logits.
        pooled = backbone(inputs)["pooled_output"]
        pooled = keras.layers.Dropout(dropout)(pooled)
        outputs = keras.layers.Dense(
            num_classes,
            kernel_initializer=albert_kernel_initializer(),
            name="logits",
        )(pooled)
        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            # NOTE(review): the meaning of `include_preprocessing` is defined
            # by `Task`, which is not visible here; it is simply set to
            # whether a preprocessor was supplied.
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )
        # All references to `self` below this line
        self._backbone = backbone
        self._preprocessor = preprocessor
        self.num_classes = num_classes
        self.dropout = dropout

    def get_config(self):
        """Return the parent config extended with the head hyperparameters.

        Returns:
            The dict produced by `super().get_config()` with `"num_classes"`
            and `"dropout"` added.
        """
        config = super().get_config()
        config.update(
            {
                "num_classes": self.num_classes,
                "dropout": self.dropout,
            }
        )

        return config

    @classproperty
    def backbone_cls(cls):
        """The backbone class associated with this task (`AlbertBackbone`)."""
        return AlbertBackbone

    @classproperty
    def preprocessor_cls(cls):
        """The preprocessor class associated with this task."""
        return AlbertPreprocessor

    @classproperty
    def presets(cls):
        """A deep copy of the backbone presets.

        Copying means callers that mutate the returned dict do not modify the
        shared `backbone_presets` mapping.
        """
        return copy.deepcopy({**backbone_presets})
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for ALBERT classification model."""
15+
16+
import io
17+
import os
18+
19+
import sentencepiece
20+
import tensorflow as tf
21+
from absl.testing import parameterized
22+
from tensorflow import keras
23+
24+
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
25+
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
26+
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
27+
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
28+
29+
30+
class AlbertClassifierTest(tf.test.TestCase, parameterized.TestCase):
    """Smoke tests for `AlbertClassifier`: call, predict, fit and save/load."""

    def setUp(self):
        """Build a tiny backbone, a tokenizer, two classifiers and inputs."""
        # Deliberately small backbone configuration.
        self.backbone = AlbertBackbone(
            vocabulary_size=1000,
            num_layers=2,
            num_heads=2,
            embedding_dim=8,
            hidden_dim=64,
            intermediate_dim=128,
            max_sequence_length=128,
            name="encoder",
        )

        # Train a throwaway SentencePiece model into an in-memory buffer.
        proto_buffer = io.BytesIO()
        training_sentences = tf.data.Dataset.from_tensor_slices(
            ["the quick brown fox", "the earth is round"]
        )
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=training_sentences.as_numpy_iterator(),
            model_writer=proto_buffer,
            vocab_size=10,
            model_type="WORD",
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            pad_piece="<pad>",
            unk_piece="<unk>",
            bos_piece="[CLS]",
            eos_piece="[SEP]",
        )
        self.proto = proto_buffer.getvalue()

        self.preprocessor = AlbertPreprocessor(
            tokenizer=AlbertTokenizer(proto=self.proto),
            sequence_length=8,
        )

        # Both classifiers wrap the same backbone instance; only the
        # presence of the preprocessor differs.
        self.classifier = AlbertClassifier(
            self.backbone,
            4,
            preprocessor=self.preprocessor,
        )
        self.classifier_no_preprocessing = AlbertClassifier(
            self.backbone,
            4,
            preprocessor=None,
        )

        # Raw and preprocessed variants of the same four-sentence batch.
        sentences = [
            "the quick brown fox.",
            "the slow brown fox.",
            "the smelly brown fox.",
            "the old brown fox.",
        ]
        self.raw_batch = tf.constant(sentences)
        self.preprocessed_batch = self.preprocessor(self.raw_batch)

        labels = tf.ones((4,))
        raw_pairs = tf.data.Dataset.from_tensor_slices(
            (self.raw_batch, labels)
        )
        self.raw_dataset = raw_pairs.batch(2)
        self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)

    def test_valid_call_classifier(self):
        """Calling the classifier directly on preprocessed inputs works."""
        self.classifier(self.preprocessed_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_albert_classifier_predict(self, jit_compile):
        """`predict()` on raw strings runs for both `jit_compile` settings."""
        model = self.classifier
        model.compile(jit_compile=jit_compile)
        model.predict(self.raw_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_albert_classifier_predict_no_preprocessing(self, jit_compile):
        """`predict()` on preprocessed inputs runs without a preprocessor."""
        model = self.classifier_no_preprocessing
        model.compile(jit_compile=jit_compile)
        model.predict(self.preprocessed_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_albert_classifier_fit(self, jit_compile):
        """`fit()` on the raw `(x, y)` dataset runs with preprocessing."""
        model = self.classifier
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(loss=loss_fn, jit_compile=jit_compile)
        model.fit(self.raw_dataset)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_albert_classifier_fit_no_preprocessing(self, jit_compile):
        """`fit()` on the preprocessed dataset runs without a preprocessor."""
        model = self.classifier_no_preprocessing
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(loss=loss_fn, jit_compile=jit_compile)
        model.fit(self.preprocessed_dataset)

    @parameterized.named_parameters(
        ("tf_format", "tf", "model"),
        ("keras_format", "keras_v3", "model.keras"),
    )
    def test_saved_model(self, save_format, filename):
        """Saving then loading returns an `AlbertClassifier` with equal output."""
        expected_output = self.classifier.predict(self.raw_batch)

        target_path = os.path.join(self.get_temp_dir(), filename)
        self.classifier.save(target_path, save_format=save_format)
        reloaded = keras.models.load_model(target_path)

        # Check we got the real object back.
        self.assertIsInstance(reloaded, AlbertClassifier)

        # Check that output matches.
        reloaded_output = reloaded.predict(self.raw_batch)
        self.assertAllClose(expected_output, reloaded_output)

0 commit comments

Comments
 (0)