Merged
174 changes: 28 additions & 146 deletions keras_nlp/models/bert/bert_classifier.py
@@ -14,7 +14,6 @@
 """BERT classification model."""
 
 import copy
-import os
 
 from tensorflow import keras

@@ -23,15 +22,15 @@
 from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
 from keras_nlp.models.bert.bert_presets import backbone_presets
 from keras_nlp.models.bert.bert_presets import classifier_presets
-from keras_nlp.utils.pipeline_model import PipelineModel
+from keras_nlp.models.task import Task
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.python_utils import format_docstring
 
 PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets))


 @keras.utils.register_keras_serializable(package="keras_nlp")
-class BertClassifier(PipelineModel):
+class BertClassifier(Task):
     """An end-to-end BERT model for classification tasks.
 
     This model attaches a classification head to a `keras_nlp.models.BertBackbone`
@@ -124,164 +123,47 @@ def __init__(
         self._backbone = backbone
         self._preprocessor = preprocessor
         self.num_classes = num_classes
-
-    def preprocess_samples(self, x, y=None, sample_weight=None):
-        return self.preprocessor(x, y=y, sample_weight=sample_weight)
-
-    @property
-    def backbone(self):
-        """A `keras_nlp.models.BertBackbone` instance providing the encoder
-        submodel.
-        """
-        return self._backbone
-
-    @property
-    def preprocessor(self):
-        """A `keras_nlp.models.BertPreprocessor` for preprocessing inputs."""
-        return self._preprocessor
+        self.dropout = dropout
 
     def get_config(self):
-        return {
-            "backbone": keras.layers.serialize(self.backbone),
-            "preprocessor": keras.layers.serialize(self.preprocessor),
-            "num_classes": self.num_classes,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "dropout": self.dropout,
+                "name": self.name,
Contributor:
Could "name" and "trainable" be added to the Task get_config? I guess it would be awkward to look for them in kwargs to pass to super().get_config()?

Member:
We should make this change. I was going to do the same for backbones.

Having name and trainable handled in a base class is actually the norm, so it will improve readability here.

"trainable": self.trainable,
}
)
return config
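The thread above proposes hoisting `name` and `trainable` into the base class. A minimal sketch of what that could look like on `Task` (a hypothetical follow-up, not code from this PR; it assumes `Task` also owns the `backbone`/`preprocessor` serialization that subclasses previously wrote by hand):

```python
from tensorflow import keras


class Task(keras.Model):
    # Hypothetical sketch of the suggested follow-up, not this PR's code.
    def get_config(self):
        # With `name` and `trainable` handled here, subclasses only
        # `update()` the config with their own constructor arguments.
        return {
            "backbone": keras.layers.serialize(self.backbone),
            "preprocessor": keras.layers.serialize(self.preprocessor),
            "name": self.name,
            "trainable": self.trainable,
        }
```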

-    @classmethod
-    def from_config(cls, config):
-        if "backbone" in config and isinstance(config["backbone"], dict):
-            config["backbone"] = keras.layers.deserialize(config["backbone"])
-        if "preprocessor" in config and isinstance(
-            config["preprocessor"], dict
-        ):
-            config["preprocessor"] = keras.layers.deserialize(
-                config["preprocessor"]
-            )
-        return cls(**config)
+    @classproperty
+    def backbone_cls(cls):
+        return BertBackbone
+
+    @classproperty
+    def preprocessor_cls(cls):
+        return BertPreprocessor
 
     @classproperty
     def presets(cls):
         return copy.deepcopy({**backbone_presets, **classifier_presets})
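These hooks rely on `classproperty` from `keras_nlp.utils.python_utils`, which lets the base class read `cls.backbone_cls`, `cls.preprocessor_cls`, and `cls.presets` without an instance. A descriptor of this kind is commonly written as below; this is an illustrative sketch, not necessarily the library's exact implementation:

```python
class classproperty(property):
    """A `@property` that resolves on the class rather than an instance."""

    def __get__(self, obj, owner_cls):
        # `self.fget` is the decorated function; pass it the class itself,
        # so `BertClassifier.backbone_cls` works with no instantiation.
        return self.fget(owner_cls)
```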

     @classmethod
-    @format_docstring(names=PRESET_NAMES)
     def from_preset(
         cls,
         preset,
         load_weights=True,
         **kwargs,
     ):
-        """Create a classification model from a preset architecture and weights.
-
-        By default, this method will automatically create a `preprocessor`
-        layer to preprocess raw inputs during `fit()`, `predict()`, and
-        `evaluate()`. If you would like to disable this behavior, pass
-        `preprocessor=None`.
-
-        Args:
-            preset: string. Must be one of {{names}}.
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
-
-        Examples:
-
-        Raw string inputs.
-        ```python
-        # Create a dataset with raw string features in an `(x, y)` format.
-        features = ["The quick brown fox jumped.", "I forgot my homework."]
-        labels = [0, 3]
-
-        # Create a BertClassifier and fit your data.
-        classifier = keras_nlp.models.BertClassifier.from_preset(
-            "bert_base_en_uncased",
-            num_classes=4,
-        )
-        classifier.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        )
-        classifier.fit(x=features, y=labels, batch_size=2)
-        ```
-
-        Raw string inputs with customized preprocessing.
-        ```python
-        # Create a dataset with raw string features in an `(x, y)` format.
-        features = ["The quick brown fox jumped.", "I forgot my homework."]
-        labels = [0, 3]
-
-        # Use a shorter sequence length.
-        preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
-            "bert_base_en_uncased",
-            sequence_length=128,
-        )
-
-        # Create a BertClassifier and fit your data.
-        classifier = keras_nlp.models.BertClassifier.from_preset(
-            "bert_base_en_uncased",
-            num_classes=4,
-            preprocessor=preprocessor,
-        )
-        classifier.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        )
-        classifier.fit(x=features, y=labels, batch_size=2)
-        ```
-
-        Preprocessed inputs.
-        ```python
-        # Create a dataset with preprocessed features in an `(x, y)` format.
-        preprocessed_features = {
-            "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
-            "segment_ids": tf.constant(
-                [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
-            ),
-            "padding_mask": tf.constant(
-                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
-            ),
-        }
-        labels = [0, 3]
-
-        # Create a BERT classifier and fit your data.
-        classifier = keras_nlp.models.BertClassifier.from_preset(
-            "bert_base_en_uncased",
-            num_classes=4,
-            preprocessor=None,
-        )
-        classifier.compile(
-            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        )
-        classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
-        ```
-        """
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-
-        if "preprocessor" not in kwargs:
-            kwargs["preprocessor"] = BertPreprocessor.from_preset(preset)
-
-        # Check if preset is backbone-only model
-        if preset in BertBackbone.presets:
-            backbone = BertBackbone.from_preset(preset, load_weights)
-            return cls(backbone, **kwargs)
-
-        # Otherwise must be one of class presets
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-
-        model.load_weights(weights)
-        return model
+        return super().from_preset(
+            preset=preset, load_weights=load_weights, **kwargs
+        )
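The deleted logic above does not disappear; it presumably lives once in `Task.from_preset`, keyed off the subclass hooks. A hedged sketch of such a generic implementation, reconstructed from the deleted BERT-specific code (illustrative only, not the actual `keras_nlp.models.task` source):

```python
import os

from tensorflow import keras


class Task(keras.Model):  # simplified stand-in for keras_nlp's Task
    @classmethod
    def from_preset(cls, preset, load_weights=True, **kwargs):
        if preset not in cls.presets:
            raise ValueError(
                f"`preset` must be one of {', '.join(cls.presets)}. "
                f"Received: {preset}."
            )
        # Subclass hooks replace the hard-coded BertPreprocessor/BertBackbone.
        if "preprocessor" not in kwargs:
            kwargs["preprocessor"] = cls.preprocessor_cls.from_preset(preset)
        if preset in cls.backbone_cls.presets:
            # Backbone-only preset: the task head stays randomly initialized.
            backbone = cls.backbone_cls.from_preset(preset, load_weights)
            return cls(backbone, **kwargs)
        # Full task preset: rebuild from config, optionally restore weights.
        metadata = cls.presets[preset]
        model = cls.from_config({**metadata["config"], **kwargs})
        if load_weights:
            weights = keras.utils.get_file(
                "model.h5",
                metadata["weights_url"],
                cache_subdir=os.path.join("models", preset),
                file_hash=metadata["weights_hash"],
            )
            model.load_weights(weights)
        return model
```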
+BertClassifier.from_preset.__func__.__doc__ = Task.from_preset.__doc__
Member:
We can switch to the approach in #654 now

+format_docstring(
+    model_task_name=BertClassifier.__name__,
+    example_preset_name="bert_base_en_uncased",
+    preset_names=PRESET_NAMES,
+)(BertClassifier.from_preset.__func__)
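Two things happen at module level here: the inherited docstring is copied onto the subclass's `from_preset`, and `format_docstring` fills in the model-specific placeholders. A rough sketch of how such a template-filling decorator could work, assuming simple `{{key}}` markers (hypothetical, not the exact `keras_nlp.utils.python_utils` code):

```python
def format_docstring(**replacements):
    """Hypothetical sketch: fill `{{key}}` placeholders in a docstring."""

    def decorator(fn):
        doc = fn.__doc__ or ""
        for key, value in replacements.items():
            # Replace every `{{key}}` marker with the supplied value.
            doc = doc.replace("{{" + key + "}}", str(value))
        fn.__doc__ = doc
        return fn

    return decorator
```

Because `from_preset` is a `classmethod`, both steps target `BertClassifier.from_preset.__func__`, the raw function the descriptor wraps; setting `__doc__` on the bound classmethod itself would raise an `AttributeError`.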