1414"""BERT classification model."""
1515
1616import copy
17- import os
1817
1918from tensorflow import keras
2019
2322from keras_nlp .models .bert .bert_preprocessor import BertPreprocessor
2423from keras_nlp .models .bert .bert_presets import backbone_presets
2524from keras_nlp .models .bert .bert_presets import classifier_presets
26- from keras_nlp .utils . pipeline_model import PipelineModel
25+ from keras_nlp .models . task import Task
2726from keras_nlp .utils .python_utils import classproperty
28- from keras_nlp .utils .python_utils import format_docstring
29-
30- PRESET_NAMES = ", " .join (list (backbone_presets ) + list (classifier_presets ))
3127
3228
3329@keras .utils .register_keras_serializable (package = "keras_nlp" )
34- class BertClassifier (PipelineModel ):
30+ class BertClassifier (Task ):
3531 """An end-to-end BERT model for classification tasks
3632
3733 This model attaches a classification head to a `keras_nlp.model.BertBackbone`
@@ -58,8 +54,9 @@ class BertClassifier(PipelineModel):
5854
5955 Examples:
6056
57+ Example usage.
6158 ```python
62- # Call classifier on the inputs.
59+ # Define the preprocessed inputs.
6360 preprocessed_features = {
6461 "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
6562 "segment_ids": tf.constant(
@@ -95,6 +92,73 @@ class BertClassifier(PipelineModel):
9592 # Access backbone programmatically (e.g., to change `trainable`)
9693 classifier.backbone.trainable = False
9794 ```
95+
96+ Raw string inputs.
97+ ```python
98+ # Create a dataset with raw string features in an `(x, y)` format.
99+ features = ["The quick brown fox jumped.", "I forgot my homework."]
100+ labels = [0, 3]
101+
102+ # Create a BertClassifier and fit your data.
103+ classifier = keras_nlp.models.BertClassifier.from_preset(
104+ "bert_base_en_uncased",
105+ num_classes=4,
106+ )
107+ classifier.compile(
108+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
109+ )
110+ classifier.fit(x=features, y=labels, batch_size=2)
111+ ```
112+
113+ Raw string inputs with customized preprocessing.
114+ ```python
115+ # Create a dataset with raw string features in an `(x, y)` format.
116+ features = ["The quick brown fox jumped.", "I forgot my homework."]
117+ labels = [0, 3]
118+
119+ # Use a shorter sequence length.
120+ preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
121+ "bert_base_en_uncased",
122+ sequence_length=128,
123+ )
124+
125+ # Create a BertClassifier and fit your data.
126+ classifier = keras_nlp.models.BertClassifier.from_preset(
127+ "bert_base_en_uncased",
128+ num_classes=4,
129+ preprocessor=preprocessor,
130+ )
131+ classifier.compile(
132+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
133+ )
134+ classifier.fit(x=features, y=labels, batch_size=2)
135+ ```
136+
137+ Preprocessed inputs.
138+ ```python
139+ # Create a dataset with preprocessed features in an `(x, y)` format.
140+ preprocessed_features = {
141+ "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
142+ "segment_ids": tf.constant(
143+ [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
144+ ),
145+ "padding_mask": tf.constant(
146+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
147+ ),
148+ }
149+ labels = [0, 3]
150+
151+ # Create a BERT classifier and fit your data.
152+ classifier = keras_nlp.models.BertClassifier.from_preset(
153+ "bert_base_en_uncased",
154+ num_classes=4,
155+ preprocessor=None,
156+ )
157+ classifier.compile(
158+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
159+ )
160+ classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
161+ ```
98162 """
99163
100164 def __init__ (
@@ -124,164 +188,26 @@ def __init__(
124188 self ._backbone = backbone
125189 self ._preprocessor = preprocessor
126190 self .num_classes = num_classes
127-
128- def preprocess_samples (self , x , y = None , sample_weight = None ):
129- return self .preprocessor (x , y = y , sample_weight = sample_weight )
130-
131- @property
132- def backbone (self ):
133- """A `keras_nlp.models.BertBackbone` instance providing the encoder
134- submodel.
135- """
136- return self ._backbone
137-
138- @property
139- def preprocessor (self ):
140- """A `keras_nlp.models.BertPreprocessor` for preprocessing inputs."""
141- return self ._preprocessor
191+ self .dropout = dropout
142192
143193 def get_config (self ):
144- return {
145- "backbone" : keras .layers .serialize (self .backbone ),
146- "preprocessor" : keras .layers .serialize (self .preprocessor ),
147- "num_classes" : self .num_classes ,
148- "name" : self .name ,
149- "trainable" : self .trainable ,
150- }
194+ config = super ().get_config ()
195+ config .update (
196+ {
197+ "num_classes" : self .num_classes ,
198+ "dropout" : self .dropout ,
199+ }
200+ )
201+ return config
202+
203+ @classproperty
204+ def backbone_cls (cls ):
205+ return BertBackbone
151206
152- @classmethod
153- def from_config (cls , config ):
154- if "backbone" in config and isinstance (config ["backbone" ], dict ):
155- config ["backbone" ] = keras .layers .deserialize (config ["backbone" ])
156- if "preprocessor" in config and isinstance (
157- config ["preprocessor" ], dict
158- ):
159- config ["preprocessor" ] = keras .layers .deserialize (
160- config ["preprocessor" ]
161- )
162- return cls (** config )
207+ @classproperty
208+ def preprocessor_cls (cls ):
209+ return BertPreprocessor
163210
164211 @classproperty
165212 def presets (cls ):
166213 return copy .deepcopy ({** backbone_presets , ** classifier_presets })
167-
168- @classmethod
169- @format_docstring (names = PRESET_NAMES )
170- def from_preset (
171- cls ,
172- preset ,
173- load_weights = True ,
174- ** kwargs ,
175- ):
176- """Create a classification model from a preset architecture and weights.
177-
178- By default, this method will automatically create a `preprocessor`
179- layer to preprocess raw inputs during `fit()`, `predict()`, and
180- `evaluate()`. If you would like to disable this behavior, pass
181- `preprocessor=None`.
182-
183- Args:
184- preset: string. Must be one of {{names}}.
185- load_weights: Whether to load pre-trained weights into model.
186- Defaults to `True`.
187-
188- Examples:
189-
190- Raw string inputs.
191- ```python
192- # Create a dataset with raw string features in an `(x, y)` format.
193- features = ["The quick brown fox jumped.", "I forgot my homework."]
194- labels = [0, 3]
195-
196- # Create a BertClassifier and fit your data.
197- classifier = keras_nlp.models.BertClassifier.from_preset(
198- "bert_base_en_uncased",
199- num_classes=4,
200- )
201- classifier.compile(
202- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
203- )
204- classifier.fit(x=features, y=labels, batch_size=2)
205- ```
206-
207- Raw string inputs with customized preprocessing.
208- ```python
209- # Create a dataset with raw string features in an `(x, y)` format.
210- features = ["The quick brown fox jumped.", "I forgot my homework."]
211- labels = [0, 3]
212-
213- # Use a shorter sequence length.
214- preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
215- "bert_base_en_uncased",
216- sequence_length=128,
217- )
218-
219- # Create a BertClassifier and fit your data.
220- classifier = keras_nlp.models.BertClassifier.from_preset(
221- "bert_base_en_uncased",
222- num_classes=4,
223- preprocessor=preprocessor,
224- )
225- classifier.compile(
226- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
227- )
228- classifier.fit(x=features, y=labels, batch_size=2)
229- ```
230-
231- Preprocessed inputs.
232- ```python
233- # Create a dataset with preprocessed features in an `(x, y)` format.
234- preprocessed_features = {
235- "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
236- "segment_ids": tf.constant(
237- [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
238- ),
239- "padding_mask": tf.constant(
240- [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
241- ),
242- }
243- labels = [0, 3]
244-
245- # Create a BERT classifier and fit your data.
246- classifier = keras_nlp.models.BertClassifier.from_preset(
247- "bert_base_en_uncased",
248- num_classes=4,
249- preprocessor=None,
250- )
251- classifier.compile(
252- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
253- )
254- classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
255- ```
256- """
257- if preset not in cls .presets :
258- raise ValueError (
259- "`preset` must be one of "
260- f"""{ ", " .join (cls .presets )} . Received: { preset } ."""
261- )
262-
263- if "preprocessor" not in kwargs :
264- kwargs ["preprocessor" ] = BertPreprocessor .from_preset (preset )
265-
266- # Check if preset is backbone-only model
267- if preset in BertBackbone .presets :
268- backbone = BertBackbone .from_preset (preset , load_weights )
269- return cls (backbone , ** kwargs )
270-
271- # Otherwise must be one of class presets
272- metadata = cls .presets [preset ]
273- config = metadata ["config" ]
274- model = cls .from_config ({** config , ** kwargs })
275-
276- if not load_weights :
277- return model
278-
279- weights = keras .utils .get_file (
280- "model.h5" ,
281- metadata ["weights_url" ],
282- cache_subdir = os .path .join ("models" , preset ),
283- file_hash = metadata ["weights_hash" ],
284- )
285-
286- model .load_weights (weights )
287- return model
0 commit comments