
Commit a6700eb

Add presets for Electra and checkpoint conversion script (#1384)
* Added ElectraBackbone
* Added backbone tests for ELECTRA
* Fix config
* Add model import to __init__
* Add electra tokenizer
* Add tests for tokenizer
* Add __init__ file
* Add tokenizer and backbone to models __init__
* Fix failing tokenization test
* Add example on usage of the tokenizer with custom vocabulary
* Add conversion script to convert weights from checkpoint
* Add electra preprocessor
* Add presets and tests
* Add presets config with model weights
* Add checkpoint conversion script
* Name conversion for electra models
* Update naming conventions according to preset names
* Fix failing tokenizer tests
* Update checkpoint conversion script according to kaggle
* Add validate function
* Kaggle preset
* Update preset link
* Add electra presets
* Complete run_small_preset test for electra
* Add large variations of electra in presets
* Fix case issues with electra presets
* Fix format

---------

Co-authored-by: Matt Watson <[email protected]>
1 parent da734ee commit a6700eb
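
For context, here is a minimal sketch of the workflow this commit enables. It is illustrative only: the preset name is taken from the docstrings in this diff, and downloading weights assumes the Kaggle-hosted presets referenced in the commit message are published.

```python
import keras_nlp

# Preprocessor and backbone loaded from the same preset, so the
# tokenizer vocabulary matches the pretrained weights.
preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
    "electra_base_discriminator_en"
)
backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_base_discriminator_en"
)

# The preprocessor maps raw strings to the dict of token_ids,
# segment_ids, and padding_mask that the backbone expects.
features = preprocessor(["The quick brown fox jumped."])
outputs = backbone(features)
print(outputs["sequence_output"].shape, outputs["pooled_output"].shape)
```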

9 files changed: +684 −2 lines changed


keras_nlp/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@
     DistilBertTokenizer,
 )
 from keras_nlp.models.electra.electra_backbone import ElectraBackbone
+from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
 from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
 from keras_nlp.models.f_net.f_net_classifier import FNetClassifier
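
With this export in place, the preprocessor joins the backbone and tokenizer in the public namespace; for example (a sketch, assuming the package layout shown above):

```python
from keras_nlp.models import (
    ElectraBackbone,
    ElectraPreprocessor,
    ElectraTokenizer,
)
```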

keras_nlp/models/electra/electra_backbone.py

Lines changed: 18 additions & 2 deletions

@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
 from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
 from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.electra.electra_presets import backbone_presets
 from keras_nlp.utils.keras_utils import gelu_approximate
+from keras_nlp.utils.python_utils import classproperty
 
 
 def electra_kernel_initializer(stddev=0.02):

@@ -36,8 +40,9 @@ class ElectraBackbone(Backbone):
     or classification task networks.
 
     The default constructor gives a fully customizable, randomly initialized
-    Electra encoder with any number of layers, heads, and embedding
-    dimensions.
+    ELECTRA encoder with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the
+    `from_preset()` constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind. The underlying model is provided by a

@@ -70,6 +75,13 @@ class ElectraBackbone(Backbone):
         "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
         "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
     }
+
+    # Pre-trained ELECTRA encoder.
+    model = keras_nlp.models.ElectraBackbone.from_preset(
+        "electra_base_discriminator_en"
+    )
+    model(input_data)
+
     # Randomly initialized Electra encoder
     backbone = keras_nlp.models.ElectraBackbone(
         vocabulary_size=1000,

@@ -234,3 +246,7 @@ def get_config(self):
             }
         )
         return config
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
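
The new `presets` classproperty is what backs `from_preset()` and the preset-enumerating tests below. A small sketch of how it can be used (the preset names themselves are defined in `electra_presets.py`, which is not shown in this diff):

```python
import keras_nlp

# List the registered ELECTRA preset names. `presets` returns a deep
# copy, so mutating the result cannot corrupt the registry.
print(list(keras_nlp.models.ElectraBackbone.presets.keys()))

# Any listed name can then be passed to `from_preset`, e.g. the one
# exercised by the backbone test below.
backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_small_discriminator_uncased_en"
)
```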

keras_nlp/models/electra/electra_backbone_test.py

Lines changed: 34 additions & 0 deletions

@@ -54,3 +54,37 @@ def test_saved_model(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
         )
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=ElectraBackbone,
+            preset="electra_small_discriminator_uncased_en",
+            input_data={
+                "token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
+                "segment_ids": ops.zeros((1, 4), dtype="int32"),
+                "padding_mask": ops.ones((1, 4), dtype="int32"),
+            },
+            expected_output_shape={
+                "sequence_output": (1, 4, 256),
+                "pooled_output": (1, 256),
+            },
+            # The forward pass from a preset should be stable!
+            expected_partial_output={
+                "sequence_output": (
+                    ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
+                ),
+                "pooled_output": (
+                    ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822])
+                ),
+            },
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in ElectraBackbone.presets:
+            self.run_preset_test(
+                cls=ElectraBackbone,
+                preset=preset,
+                input_data=self.input_data,
+            )
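
The "stable forward pass" comment means the test pins the first few values of each output to golden numbers recorded at conversion time. A rough equivalent of that check written directly (the exact slice and tolerance used by `run_preset_test` are assumptions here):

```python
import numpy as np

import keras_nlp
from keras_nlp.backend import ops

backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_small_discriminator_uncased_en"
)
outputs = backbone(
    {
        "token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
        "segment_ids": ops.zeros((1, 4), dtype="int32"),
        "padding_mask": ops.ones((1, 4), dtype="int32"),
    }
)
# Compare a small slice of the sequence output against the recorded
# golden values; any weight regression shows up as a mismatch here.
np.testing.assert_allclose(
    np.asarray(outputs["sequence_output"])[0, 0, :5],
    [0.32287, 0.18754, -0.22272, -0.24177, 1.18977],
    atol=1e-4,
)
```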
keras_nlp/models/electra/electra_preprocessor.py (new file)

Lines changed: 163 additions & 0 deletions

# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.multi_segment_packer import (
    MultiSegmentPacker,
)
from keras_nlp.models.electra.electra_presets import backbone_presets
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
    convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.ElectraPreprocessor")
class ElectraPreprocessor(Preprocessor):
    """An ELECTRA preprocessing layer which tokenizes and packs inputs.

    This preprocessing layer will do three things:

    1. Tokenize any number of input segments using the `tokenizer`.
    2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
       with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens.
    3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"`
       that can be passed directly to an ELECTRA model.

    This layer can be used directly with `tf.data.Dataset.map` to preprocess
    string data in the `(x, y, sample_weight)` format used by
    `keras.Model.fit`.

    Args:
        tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
        sequence_length: The length of the packed inputs.
        truncate: string. The algorithm to truncate a list of batched segments
            to fit within `sequence_length`. The value can be either
            `round_robin` or `waterfall`:
            - `"round_robin"`: Available space is assigned one token at a
              time in a round-robin fashion to the inputs that still need
              some, until the limit is reached.
            - `"waterfall"`: The allocation of the budget is done using a
              "waterfall" algorithm that allocates quota in a
              left-to-right manner and fills up the buckets until we run
              out of budget. It supports an arbitrary number of segments.

    Call arguments:
        x: A tensor of single string sequences, or a tuple of multiple
            tensor sequences to be packed together. Inputs may be batched or
            unbatched. For single sequences, raw python inputs will be
            converted to tensors. For multiple sequences, pass tensors
            directly.
        y: Any label data. Will be passed through unaltered.
        sample_weight: Any label weight data. Will be passed through
            unaltered.

    Examples:

    Directly calling the layer on data.
    ```python
    preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
        "electra_base_discriminator_en"
    )
    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])

    # Custom vocabulary.
    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
    tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
    preprocessor = keras_nlp.models.ElectraPreprocessor(tokenizer)
    preprocessor("The quick brown fox jumped.")
    ```

    Mapping with `tf.data.Dataset`.
    ```python
    preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
        "electra_base_discriminator_en"
    )

    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
    label = tf.constant([1, 1])

    # Map labeled single sentences.
    ds = tf.data.Dataset.from_tensor_slices((first, label))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map unlabeled single sentences.
    ds = tf.data.Dataset.from_tensor_slices(first)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map labeled sentence pairs.
    ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map unlabeled sentence pairs.
    ds = tf.data.Dataset.from_tensor_slices((first, second))
    # Watch out for tf.data's default unpacking of tuples here!
    # Best to invoke the `preprocessor` directly in this case.
    ds = ds.map(
        lambda first, second: preprocessor(x=(first, second)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ```
    """

    def __init__(
        self,
        tokenizer,
        sequence_length=512,
        truncate="round_robin",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.packer = MultiSegmentPacker(
            start_value=self.tokenizer.cls_token_id,
            end_value=self.tokenizer.sep_token_id,
            pad_value=self.tokenizer.pad_token_id,
            truncate=truncate,
            sequence_length=sequence_length,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.packer.sequence_length,
                "truncate": self.packer.truncate,
            }
        )
        return config

    def call(self, x, y=None, sample_weight=None):
        x = convert_inputs_to_list_of_tensor_segments(x)
        x = [self.tokenizer(segment) for segment in x]
        token_ids, segment_ids = self.packer(x)
        x = {
            "token_ids": token_ids,
            "segment_ids": segment_ids,
            "padding_mask": token_ids != self.tokenizer.pad_token_id,
        }
        return pack_x_y_sample_weight(x, y, sample_weight)

    @classproperty
    def tokenizer_cls(cls):
        return ElectraTokenizer

    @classproperty
    def presets(cls):
        return copy.deepcopy({**backbone_presets})
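
The `truncate` argument is easiest to see on a concrete over-length input. A sketch using the same `MultiSegmentPacker` the preprocessor wraps (token ids are arbitrary illustrative values; 2/3/0 stand in for the `[CLS]`/`[SEP]`/`[PAD]` ids a real tokenizer would supply):

```python
from keras_nlp.layers import MultiSegmentPacker

# Two tokenized segments that together exceed sequence_length=8.
# With [CLS] + [SEP] + [SEP] taking 3 slots, only 5 token slots remain.
seg_a = [10, 11, 12, 13, 14]
seg_b = [20, 21, 22, 23, 24]

for truncate in ("round_robin", "waterfall"):
    packer = MultiSegmentPacker(
        start_value=2,  # [CLS]
        end_value=3,    # [SEP]
        pad_value=0,    # [PAD]
        truncate=truncate,
        sequence_length=8,
    )
    token_ids, segment_ids = packer([seg_a, seg_b])
    # "round_robin" shares the 5 remaining slots between the segments;
    # "waterfall" gives segment A as much as possible before B gets any.
    print(truncate, token_ids, segment_ids)
```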
keras_nlp/models/electra/electra_preprocessor_test.py (new file)

Lines changed: 67 additions & 0 deletions

# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.tests.test_case import TestCase


class ElectraPreprocessorTest(TestCase):
    def setUp(self):
        self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        self.vocab += ["THE", "QUICK", "BROWN", "FOX"]
        self.vocab += ["the", "quick", "brown", "fox"]
        self.tokenizer = ElectraTokenizer(vocabulary=self.vocab)
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = (
            ["THE QUICK BROWN FOX."],
            [1],  # Pass through labels.
            [1.0],  # Pass through sample_weights.
        )

    def test_preprocessor_basics(self):
        self.run_preprocessing_layer_test(
            cls=ElectraPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[2, 5, 6, 7, 8, 1, 3, 0]],
                    "segment_ids": [[0, 0, 0, 0, 0, 0, 0, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
                },
                [1],  # Pass through labels.
                [1.0],  # Pass through sample_weights.
            ),
        )

    def test_errors_for_2d_list_input(self):
        preprocessor = ElectraPreprocessor(**self.init_kwargs)
        ambiguous_input = [["one", "two"], ["three", "four"]]
        with self.assertRaises(ValueError):
            preprocessor(ambiguous_input)

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in ElectraPreprocessor.presets:
            self.run_preset_test(
                cls=ElectraPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )
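
The expected `token_ids` in `test_preprocessor_basics` follow directly from positions in the test vocabulary; a quick check of the arithmetic:

```python
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab += ["THE", "QUICK", "BROWN", "FOX"]
vocab += ["the", "quick", "brown", "fox"]

# "THE QUICK BROWN FOX." packs to [CLS] THE QUICK BROWN FOX [UNK] [SEP] [PAD]:
# the trailing "." is out of vocabulary and maps to [UNK], and one [PAD]
# fills the sequence out to sequence_length=8.
ids = [
    vocab.index(t)
    for t in ["[CLS]", "THE", "QUICK", "BROWN", "FOX", "[UNK]", "[SEP]", "[PAD]"]
]
assert ids == [2, 5, 6, 7, 8, 1, 3, 0]
```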
