 
 @keras_nlp_export("keras_nlp.models.BertClassifier")
 class BertClassifier(Task):
33- """An end-to-end BERT model for classification tasks
33+ """An end-to-end BERT model for classification tasks.
3434
-    This model attaches a classification head to a `keras_nlp.model.BertBackbone`
-    backbone, mapping from the backbone outputs to logit output suitable for
-    a classification task. For usage of this model with pre-trained weights, see
-    the `from_preset()` method.
+    This model attaches a classification head to a
+    `keras_nlp.models.BertBackbone` instance, mapping from the backbone outputs
+    to logits suitable for a classification task. For usage of this model with
+    pre-trained weights, use the `from_preset()` constructor.
 
     This model can optionally be configured with a `preprocessor` layer, in
     which case it will automatically apply preprocessing to raw inputs during
@@ -56,90 +56,34 @@ class BertClassifier(Task):
 
     Examples:
 
-    Example usage.
+    Raw string data.
     ```python
-    # Define the preprocessed inputs.
-    preprocessed_features = {
-        "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
-        "segment_ids": tf.constant(
-            [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
-        ),
-        "padding_mask": tf.constant(
-            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
-        ),
-    }
-    labels = [0, 3]
-
-    # Randomly initialize a BERT backbone.
-    backbone = keras_nlp.models.BertBackbone(
-        vocabulary_size=30552,
-        num_layers=12,
-        num_heads=12,
-        hidden_dim=768,
-        intermediate_dim=3072,
-        max_sequence_length=12
-    )
-
-    # Create a BERT classifier and fit your data.
-    classifier = keras_nlp.models.BertClassifier(
-        backbone,
-        num_classes=4,
-        preprocessor=None,
-    )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    )
-    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
-
-    # Access backbone programatically (e.g., to change `trainable`)
-    classifier.backbone.trainable = False
-    ```
-
-    Raw string inputs.
-    ```python
-    # Create a dataset with raw string features in an `(x, y)` format.
     features = ["The quick brown fox jumped.", "I forgot my homework."]
     labels = [0, 3]
 
-    # Create a BertClassifier and fit your data.
+    # Pretrained classifier.
     classifier = keras_nlp.models.BertClassifier.from_preset(
         "bert_base_en_uncased",
         num_classes=4,
     )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    )
     classifier.fit(x=features, y=labels, batch_size=2)
-    ```
-
-    Raw string inputs with customized preprocessing.
-    ```python
-    # Create a dataset with raw string features in an `(x, y)` format.
-    features = ["The quick brown fox jumped.", "I forgot my homework."]
-    labels = [0, 3]
-
-    # Use a shorter sequence length.
-    preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
-        "bert_base_en_uncased",
-        sequence_length=128,
-    )
+    classifier.predict(x=features, batch_size=2)
 
-    # Create a BertClassifier and fit your data.
-    classifier = keras_nlp.models.BertClassifier.from_preset(
-        "bert_base_en_uncased",
-        num_classes=4,
-        preprocessor=preprocessor,
-    )
+    # Re-compile (e.g., with a new learning rate).
     classifier.compile(
         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+        jit_compile=True,
     )
+    # Access backbone programmatically (e.g., to change `trainable`).
+    classifier.backbone.trainable = False
+    # Fit again.
     classifier.fit(x=features, y=labels, batch_size=2)
     ```
 
-    Preprocessed inputs.
+    Preprocessed integer data.
     ```python
-    # Create a dataset with preprocessed features in an `(x, y)` format.
-    preprocessed_features = {
+    features = {
         "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
         "segment_ids": tf.constant(
             [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
@@ -150,16 +94,43 @@ class BertClassifier(Task):
     }
     labels = [0, 3]
 
-    # Create a BERT classifier and fit your data.
+    # Pretrained classifier without preprocessing.
     classifier = keras_nlp.models.BertClassifier.from_preset(
         "bert_base_en_uncased",
         num_classes=4,
         preprocessor=None,
     )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+
+    Custom backbone and vocabulary.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    labels = [0, 3]
+
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_nlp.models.BertTokenizer(
+        vocabulary=vocab,
+    )
+    preprocessor = keras_nlp.models.BertPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
     )
-    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
+    backbone = keras_nlp.models.BertBackbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        max_sequence_length=128,
+    )
+    classifier = keras_nlp.models.BertClassifier(
+        backbone=backbone,
+        preprocessor=preprocessor,
+        num_classes=4,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
     ```
     """
 
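The updated examples call `fit()` directly and reserve `compile()` for overriding defaults. A minimal sketch of that workflow, assuming a KerasNLP version where task models come pre-compiled with sensible defaults (the preset name and class count are taken from the examples above; the one-sample inputs are illustrative only):

```python
import keras_nlp

# Pretrained classifier; assumes the task is pre-compiled with a default
# loss and optimizer, so training can start without calling `compile()`.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=4,
)
classifier.fit(x=["The quick brown fox jumped."], y=[0], batch_size=1)

# When constructing with `preprocessor=None`, the "token_ids" /
# "segment_ids" / "padding_mask" dict shown in the "Preprocessed integer
# data" example can be produced by calling the matching preprocessor.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_base_en_uncased"
)
features = preprocessor(["The quick brown fox jumped."])
```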