Skip to content

Commit 1cb93f2

Browse files
authored
Add Base Task Class (#671)
* Add Base Task Class * Minor fix * Fixes * Minor edit * Address comments
1 parent 2f6e398 commit 1cb93f2

File tree

6 files changed

+588
-717
lines changed

6 files changed

+588
-717
lines changed

keras_nlp/models/bert/bert_classifier.py

Lines changed: 87 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""BERT classification model."""
1515

1616
import copy
17-
import os
1817

1918
from tensorflow import keras
2019

@@ -23,15 +22,12 @@
2322
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
2423
from keras_nlp.models.bert.bert_presets import backbone_presets
2524
from keras_nlp.models.bert.bert_presets import classifier_presets
26-
from keras_nlp.utils.pipeline_model import PipelineModel
25+
from keras_nlp.models.task import Task
2726
from keras_nlp.utils.python_utils import classproperty
28-
from keras_nlp.utils.python_utils import format_docstring
29-
30-
PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets))
3127

3228

3329
@keras.utils.register_keras_serializable(package="keras_nlp")
34-
class BertClassifier(PipelineModel):
30+
class BertClassifier(Task):
3531
"""An end-to-end BERT model for classification tasks
3632
3733
This model attaches a classification head to a `keras_nlp.model.BertBackbone`
@@ -58,8 +54,9 @@ class BertClassifier(PipelineModel):
5854
5955
Examples:
6056
57+
Example usage with preprocessed inputs.
6158
```python
62-
# Call classifier on the inputs.
59+
# Define the preprocessed inputs.
6360
preprocessed_features = {
6461
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
6562
"segment_ids": tf.constant(
@@ -95,6 +92,73 @@ class BertClassifier(PipelineModel):
9592
# Access backbone programmatically (e.g., to change `trainable`)
9693
classifier.backbone.trainable = False
9794
```
95+
96+
Raw string inputs.
97+
```python
98+
# Create a dataset with raw string features in an `(x, y)` format.
99+
features = ["The quick brown fox jumped.", "I forgot my homework."]
100+
labels = [0, 3]
101+
102+
# Create a BertClassifier and fit your data.
103+
classifier = keras_nlp.models.BertClassifier.from_preset(
104+
"bert_base_en_uncased",
105+
num_classes=4,
106+
)
107+
classifier.compile(
108+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
109+
)
110+
classifier.fit(x=features, y=labels, batch_size=2)
111+
```
112+
113+
Raw string inputs with customized preprocessing.
114+
```python
115+
# Create a dataset with raw string features in an `(x, y)` format.
116+
features = ["The quick brown fox jumped.", "I forgot my homework."]
117+
labels = [0, 3]
118+
119+
# Use a shorter sequence length.
120+
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
121+
"bert_base_en_uncased",
122+
sequence_length=128,
123+
)
124+
125+
# Create a BertClassifier and fit your data.
126+
classifier = keras_nlp.models.BertClassifier.from_preset(
127+
"bert_base_en_uncased",
128+
num_classes=4,
129+
preprocessor=preprocessor,
130+
)
131+
classifier.compile(
132+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
133+
)
134+
classifier.fit(x=features, y=labels, batch_size=2)
135+
```
136+
137+
Preprocessed inputs.
138+
```python
139+
# Create a dataset with preprocessed features in an `(x, y)` format.
140+
preprocessed_features = {
141+
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
142+
"segment_ids": tf.constant(
143+
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
144+
),
145+
"padding_mask": tf.constant(
146+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
147+
),
148+
}
149+
labels = [0, 3]
150+
151+
# Create a BERT classifier and fit your data.
152+
classifier = keras_nlp.models.BertClassifier.from_preset(
153+
"bert_base_en_uncased",
154+
num_classes=4,
155+
preprocessor=None,
156+
)
157+
classifier.compile(
158+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
159+
)
160+
classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
161+
```
98162
"""
99163

100164
def __init__(
@@ -124,164 +188,26 @@ def __init__(
124188
self._backbone = backbone
125189
self._preprocessor = preprocessor
126190
self.num_classes = num_classes
127-
128-
def preprocess_samples(self, x, y=None, sample_weight=None):
129-
return self.preprocessor(x, y=y, sample_weight=sample_weight)
130-
131-
@property
132-
def backbone(self):
133-
"""A `keras_nlp.models.BertBackbone` instance providing the encoder
134-
submodel.
135-
"""
136-
return self._backbone
137-
138-
@property
139-
def preprocessor(self):
140-
"""A `keras_nlp.models.BertPreprocessor` for preprocessing inputs."""
141-
return self._preprocessor
191+
self.dropout = dropout
142192

143193
def get_config(self):
144-
return {
145-
"backbone": keras.layers.serialize(self.backbone),
146-
"preprocessor": keras.layers.serialize(self.preprocessor),
147-
"num_classes": self.num_classes,
148-
"name": self.name,
149-
"trainable": self.trainable,
150-
}
194+
config = super().get_config()
195+
config.update(
196+
{
197+
"num_classes": self.num_classes,
198+
"dropout": self.dropout,
199+
}
200+
)
201+
return config
202+
203+
@classproperty
204+
def backbone_cls(cls):
205+
return BertBackbone
151206

152-
@classmethod
153-
def from_config(cls, config):
154-
if "backbone" in config and isinstance(config["backbone"], dict):
155-
config["backbone"] = keras.layers.deserialize(config["backbone"])
156-
if "preprocessor" in config and isinstance(
157-
config["preprocessor"], dict
158-
):
159-
config["preprocessor"] = keras.layers.deserialize(
160-
config["preprocessor"]
161-
)
162-
return cls(**config)
207+
@classproperty
208+
def preprocessor_cls(cls):
209+
return BertPreprocessor
163210

164211
@classproperty
165212
def presets(cls):
166213
return copy.deepcopy({**backbone_presets, **classifier_presets})
167-
168-
@classmethod
169-
@format_docstring(names=PRESET_NAMES)
170-
def from_preset(
171-
cls,
172-
preset,
173-
load_weights=True,
174-
**kwargs,
175-
):
176-
"""Create a classification model from a preset architecture and weights.
177-
178-
By default, this method will automatically create a `preprocessor`
179-
layer to preprocess raw inputs during `fit()`, `predict()`, and
180-
`evaluate()`. If you would like to disable this behavior, pass
181-
`preprocessor=None`.
182-
183-
Args:
184-
preset: string. Must be one of {{names}}.
185-
load_weights: Whether to load pre-trained weights into model.
186-
Defaults to `True`.
187-
188-
Examples:
189-
190-
Raw string inputs.
191-
```python
192-
# Create a dataset with raw string features in an `(x, y)` format.
193-
features = ["The quick brown fox jumped.", "I forgot my homework."]
194-
labels = [0, 3]
195-
196-
# Create a BertClassifier and fit your data.
197-
classifier = keras_nlp.models.BertClassifier.from_preset(
198-
"bert_base_en_uncased",
199-
num_classes=4,
200-
)
201-
classifier.compile(
202-
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
203-
)
204-
classifier.fit(x=features, y=labels, batch_size=2)
205-
```
206-
207-
Raw string inputs with customized preprocessing.
208-
```python
209-
# Create a dataset with raw string features in an `(x, y)` format.
210-
features = ["The quick brown fox jumped.", "I forgot my homework."]
211-
labels = [0, 3]
212-
213-
# Use a shorter sequence length.
214-
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
215-
"bert_base_en_uncased",
216-
sequence_length=128,
217-
)
218-
219-
# Create a BertClassifier and fit your data.
220-
classifier = keras_nlp.models.BertClassifier.from_preset(
221-
"bert_base_en_uncased",
222-
num_classes=4,
223-
preprocessor=preprocessor,
224-
)
225-
classifier.compile(
226-
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
227-
)
228-
classifier.fit(x=features, y=labels, batch_size=2)
229-
```
230-
231-
Preprocessed inputs.
232-
```python
233-
# Create a dataset with preprocessed features in an `(x, y)` format.
234-
preprocessed_features = {
235-
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
236-
"segment_ids": tf.constant(
237-
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
238-
),
239-
"padding_mask": tf.constant(
240-
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
241-
),
242-
}
243-
labels = [0, 3]
244-
245-
# Create a BERT classifier and fit your data.
246-
classifier = keras_nlp.models.BertClassifier.from_preset(
247-
"bert_base_en_uncased",
248-
num_classes=4,
249-
preprocessor=None,
250-
)
251-
classifier.compile(
252-
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
253-
)
254-
classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
255-
```
256-
"""
257-
if preset not in cls.presets:
258-
raise ValueError(
259-
"`preset` must be one of "
260-
f"""{", ".join(cls.presets)}. Received: {preset}."""
261-
)
262-
263-
if "preprocessor" not in kwargs:
264-
kwargs["preprocessor"] = BertPreprocessor.from_preset(preset)
265-
266-
# Check if preset is backbone-only model
267-
if preset in BertBackbone.presets:
268-
backbone = BertBackbone.from_preset(preset, load_weights)
269-
return cls(backbone, **kwargs)
270-
271-
# Otherwise must be one of class presets
272-
metadata = cls.presets[preset]
273-
config = metadata["config"]
274-
model = cls.from_config({**config, **kwargs})
275-
276-
if not load_weights:
277-
return model
278-
279-
weights = keras.utils.get_file(
280-
"model.h5",
281-
metadata["weights_url"],
282-
cache_subdir=os.path.join("models", preset),
283-
file_hash=metadata["weights_hash"],
284-
)
285-
286-
model.load_weights(weights)
287-
return model

0 commit comments

Comments
 (0)