Commit 7ed1faf

Improve bert docstrings (#843)
This reworks the full set of BERT models and preprocessing layers for consistency and, in particular, for progressive disclosure of complexity: in the example blocks, we always lead with the simplest usage (e.g. `from_preset()`) and show more complex usages further down.
1 parent a7f0311 commit 7ed1faf
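
For context on that pattern, here is a hedged sketch (not part of the commit) of the kind of leading example the reworked docstrings now open with, using the `bert_base_en_uncased` preset and the `from_preset()` call that appear throughout the diffs below:

```python
import keras_nlp

# Simplest usage leads: a pretrained classifier built in one call.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)
# Raw strings go straight in; the bundled preprocessor handles tokenization.
classifier.predict(["The quick brown fox jumped."])
```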

File tree

6 files changed, +224 -234 lines changed

keras_nlp/models/bert/bert_backbone.py

Lines changed: 13 additions & 13 deletions
@@ -33,17 +33,17 @@ def bert_kernel_initializer(stddev=0.02):
 
 @keras_nlp_export("keras_nlp.models.BertBackbone")
 class BertBackbone(Backbone):
-    """BERT encoder network.
+    """A BERT encoder network.
 
     This class implements a bi-directional Transformer-based encoder as
     described in ["BERT: Pre-training of Deep Bidirectional Transformers for
     Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the
     embedding lookups and transformer layers, but not the masked language model
     or next sentence prediction heads.
 
-    The default constructor gives a fully customizable, randomly initialized BERT
-    encoder with any number of layers, heads, and embedding dimensions. To load
-    preset architectures and weights, use the `from_preset` constructor.
+    The default constructor gives a fully customizable, randomly initialized
+    BERT encoder with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset()` constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind.
@@ -76,20 +76,20 @@ class BertBackbone(Backbone):
         ),
     }
 
-    # Pretrained BERT encoder
+    # Pretrained BERT encoder.
     model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")
-    output = model(input_data)
+    model(input_data)
 
-    # Randomly initialized BERT encoder with a custom config
+    # Randomly initialized BERT encoder with a custom config.
     model = keras_nlp.models.BertBackbone(
         vocabulary_size=30552,
-        num_layers=12,
-        num_heads=12,
-        hidden_dim=768,
-        intermediate_dim=3072,
-        max_sequence_length=12,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        max_sequence_length=128,
    )
-    output = model(input_data)
+    model(input_data)
     ```
     """
 
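
Note that the second hunk above begins partway through the docstring's usage example: the `input_data` passed to the model is defined just above the hunk and is not shown here. A hedged sketch of what that setup presumably looks like, mirroring the preprocessed-features dict in the classifier docstring below:

```python
import tensorflow as tf
import keras_nlp

# Assumed shape of the `input_data` dict referenced in the hunk above:
# token ids plus segment ids and a padding mask, all of shape (batch, seq_len).
input_data = {
    "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
    "segment_ids": tf.constant(
        [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]], shape=(1, 12)
    ),
    "padding_mask": tf.constant(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], shape=(1, 12)
    ),
}
model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")
model(input_data)
```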

keras_nlp/models/bert/bert_classifier.py

Lines changed: 47 additions & 76 deletions
@@ -30,12 +30,12 @@
 
 @keras_nlp_export("keras_nlp.models.BertClassifier")
 class BertClassifier(Task):
-    """An end-to-end BERT model for classification tasks
+    """An end-to-end BERT model for classification tasks.
 
-    This model attaches a classification head to a `keras_nlp.model.BertBackbone`
-    backbone, mapping from the backbone outputs to logit output suitable for
-    a classification task. For usage of this model with pre-trained weights, see
-    the `from_preset()` method.
+    This model attaches a classification head to a
+    `keras_nlp.model.BertBackbone` instance, mapping from the backbone outputs
+    to logits suitable for a classification task. For usage of this model with
+    pre-trained weights, use the `from_preset()` constructor.
 
     This model can optionally be configured with a `preprocessor` layer, in
     which case it will automatically apply preprocessing to raw inputs during
@@ -56,90 +56,34 @@ class BertClassifier(Task):
 
     Examples:
 
-    Example usage.
+    Raw string data.
     ```python
-    # Define the preprocessed inputs.
-    preprocessed_features = {
-        "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
-        "segment_ids": tf.constant(
-            [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
-        ),
-        "padding_mask": tf.constant(
-            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
-        ),
-    }
-    labels = [0, 3]
-
-    # Randomly initialize a BERT backbone.
-    backbone = keras_nlp.models.BertBackbone(
-        vocabulary_size=30552,
-        num_layers=12,
-        num_heads=12,
-        hidden_dim=768,
-        intermediate_dim=3072,
-        max_sequence_length=12
-    )
-
-    # Create a BERT classifier and fit your data.
-    classifier = keras_nlp.models.BertClassifier(
-        backbone,
-        num_classes=4,
-        preprocessor=None,
-    )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    )
-    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
-
-    # Access backbone programatically (e.g., to change `trainable`)
-    classifier.backbone.trainable = False
-    ```
-
-    Raw string inputs.
-    ```python
-    # Create a dataset with raw string features in an `(x, y)` format.
     features = ["The quick brown fox jumped.", "I forgot my homework."]
     labels = [0, 3]
 
-    # Create a BertClassifier and fit your data.
+    # Pretrained classifier.
     classifier = keras_nlp.models.BertClassifier.from_preset(
         "bert_base_en_uncased",
         num_classes=4,
     )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    )
     classifier.fit(x=features, y=labels, batch_size=2)
-    ```
-
-    Raw string inputs with customized preprocessing.
-    ```python
-    # Create a dataset with raw string features in an `(x, y)` format.
-    features = ["The quick brown fox jumped.", "I forgot my homework."]
-    labels = [0, 3]
-
-    # Use a shorter sequence length.
-    preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
-        "bert_base_en_uncased",
-        sequence_length=128,
-    )
+    classifier.predict(x=features, batch_size=2)
 
-    # Create a BertClassifier and fit your data.
-    classifier = keras_nlp.models.BertClassifier.from_preset(
-        "bert_base_en_uncased",
-        num_classes=4,
-        preprocessor=preprocessor,
-    )
+    # Re-compile (e.g., with a new learning rate).
     classifier.compile(
         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+        jit_compile=True,
     )
+    # Access backbone programatically (e.g., to change `trainable`).
+    classifier.backbone.trainable = False
+    # Fit again.
     classifier.fit(x=features, y=labels, batch_size=2)
     ```
 
-    Preprocessed inputs.
+    Preprocessed integer data.
     ```python
-    # Create a dataset with preprocessed features in an `(x, y)` format.
-    preprocessed_features = {
+    features = {
         "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
         "segment_ids": tf.constant(
             [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
@@ -150,16 +94,43 @@ class BertClassifier(Task):
     }
     labels = [0, 3]
 
-    # Create a BERT classifier and fit your data.
+    # Pretrained classifier without preprocessing.
     classifier = keras_nlp.models.BertClassifier.from_preset(
         "bert_base_en_uncased",
         num_classes=4,
         preprocessor=None,
     )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+
+    Custom backbone and vocabulary.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    labels = [0, 3]
+
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_nlp.models.BertTokenizer(
+        vocabulary=vocab,
+    )
+    preprocessor = keras_nlp.models.BertPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
     )
-    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
+    backbone = keras_nlp.models.BertBackbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        max_sequence_length=128,
+    )
+    classifier = keras_nlp.models.BertClassifier(
+        backbone=backbone,
+        preprocessor=preprocessor,
+        num_classes=4,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
     ```
     """
 
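
The new "Custom backbone and vocabulary" example hands raw strings to `fit()` and lets the attached preprocessor do the tokenization. As a hedged sketch (not part of the commit), the same preprocessor can be called directly to see how raw strings become the `token_ids` / `segment_ids` / `padding_mask` dict used in the "Preprocessed integer data" example:

```python
import keras_nlp

# Rebuild the small WordPiece vocabulary from the docstring example.
vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
vocab += ["The", "quick", "brown", "fox", "jumped", "."]
tokenizer = keras_nlp.models.BertTokenizer(vocabulary=vocab)
preprocessor = keras_nlp.models.BertPreprocessor(
    tokenizer=tokenizer,
    sequence_length=12,
)

# Calling the preprocessor on raw strings yields the packed feature dict:
# [CLS] tokens [SEP], then [PAD] up to `sequence_length`.
x = preprocessor(["The quick brown fox jumped."])
print(x["token_ids"])  # integer ids of shape (batch, sequence_length)
```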

keras_nlp/models/bert/bert_masked_lm.py

Lines changed: 18 additions & 25 deletions
@@ -37,7 +37,7 @@ class BertMaskedLM(Task):
     This model will train BERT on a masked language modeling task.
     The model will predict labels for a number of masked tokens in the
     input data. For usage of this model with pre-trained weights, see the
-    `from_preset()` method.
+    `from_preset()` constructor.
 
     This model can optionally be configured with a `preprocessor` layer, in
     which case inputs can be raw string features during `fit()`, `predict()`,
@@ -56,26 +56,32 @@ class BertMaskedLM(Task):
 
     Example usage:
 
-    Raw string inputs and pretrained backbone.
+    Raw string data.
     ```python
-    # Create a dataset with raw string features. Labels are inferred.
     features = ["The quick brown fox jumped.", "I forgot my homework."]
 
-    # Create a BertMaskedLM with a pretrained backbone and further train
-    # on an MLM task.
+    # Pretrained language model.
    masked_lm = keras_nlp.models.BertMaskedLM.from_preset(
-        "bert_base_en",
+        "bert_base_en_uncased",
    )
+    masked_lm.fit(x=features, batch_size=2)
+
+    # Re-compile (e.g., with a new learning rate).
     masked_lm.compile(
         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+        jit_compile=True,
     )
+    # Access backbone programatically (e.g., to change `trainable`).
+    masked_lm.backbone.trainable = False
+    # Fit again.
     masked_lm.fit(x=features, batch_size=2)
     ```
 
-    Preprocessed inputs and custom backbone.
+    Preprocessed integer data.
     ```python
-    # Create a preprocessed dataset where 0 is the mask token.
-    preprocessed_features = {
+    # Create preprocessed batch where 0 is the mask token.
+    features = {
         "token_ids": tf.constant(
             [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
         ),
@@ -88,24 +94,11 @@ class BertMaskedLM(Task):
     # Labels are the original masked values.
     labels = [[3, 5]] * 2
 
-    # Randomly initialize a BERT encoder
-    backbone = keras_nlp.models.BertBackbone(
-        vocabulary_size=50265,
-        num_layers=12,
-        num_heads=12,
-        hidden_dim=768,
-        intermediate_dim=3072,
-        max_sequence_length=12
-    )
-    # Create a BERT masked LM model and fit the data.
-    masked_lm = keras_nlp.models.BertMaskedLM(
-        backbone,
+    masked_lm = keras_nlp.models.BertMaskedLM.from_preset(
+        "bert_base_en_uncased",
         preprocessor=None,
     )
-    masked_lm.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-    )
-    masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
+    masked_lm.fit(x=features, y=labels, batch_size=2)
     ```
     """
 
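
The "Preprocessed integer data" hunk above is split by the diff, so the rest of the `features` dict is not shown. As a hedged sketch based on keras_nlp's masked-LM input format (an assumption, not taken from this commit), the dict typically also carries a padding mask and the positions of the masked tokens:

```python
import tensorflow as tf

# Assumed remainder of the `features` dict from the hunk above:
# 0 is the mask token; positions 2 and 4 were masked, and the labels
# [[3, 5]] * 2 are the original token ids at those positions.
features = {
    "token_ids": tf.constant([[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)),
    "padding_mask": tf.constant([[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)),
    "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)),
}
labels = [[3, 5]] * 2
```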

0 commit comments