Commit 2f6e398

Add ALBERT Presets (#655)
* Add ALBERT presets
* Add GDrive links
* Fix preset UTs
* Fix hash
* Regenerate XL and XXL
* Remove cruft
* Add GCP URLs
1 parent 8ea419b commit 2f6e398

File tree

5 files changed: +303 -2 lines changed


keras_nlp/models/albert/albert_backbone.py

Lines changed: 8 additions & 0 deletions
@@ -14,12 +14,16 @@
 
 """ALBERT backbone model."""
 
+import copy
+
 import tensorflow as tf
 from tensorflow import keras
 
 from keras_nlp.layers.position_embedding import PositionEmbedding
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
+from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.models.backbone import Backbone
+from keras_nlp.utils.python_utils import classproperty
 
 
 def albert_kernel_initializer(stddev=0.02):

@@ -264,3 +268,7 @@ def get_config(self):
             "name": self.name,
             "trainable": self.trainable,
         }
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
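The new `presets` property is exposed through `classproperty`, so it can be read directly off the class (`AlbertBackbone.presets`) without instantiating a model, and `copy.deepcopy` keeps callers from mutating the shared preset dict. The helper itself is not part of this diff; below is a minimal sketch consistent with how it is used here — the actual `keras_nlp.utils.python_utils` implementation may differ.

```python
# Hypothetical sketch of the `classproperty` helper, inferred from usage;
# the real keras_nlp.utils.python_utils implementation may differ.
class classproperty(property):
    """A @property that resolves on the class rather than an instance."""

    def __get__(self, obj, owner_cls=None):
        # `self.fget` is the decorated function; call it with the class so
        # `SomeModel.presets` works without constructing a model.
        return self.fget(owner_cls)


class Demo:
    @classproperty
    def presets(cls):
        return {"albert_base_en_uncased": {}}


print(Demo.presets)  # {'albert_base_en_uncased': {}} -- no instance needed.
```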

keras_nlp/models/albert/albert_preprocessor.py

Lines changed: 7 additions & 0 deletions
@@ -13,9 +13,12 @@
 # limitations under the License.
 """ALBERT preprocessor layer."""
 
+import copy
+
 from tensorflow import keras
 
 from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
+from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.models.preprocessor import Preprocessor
 from keras_nlp.utils.keras_utils import (

@@ -179,3 +182,7 @@ def call(self, x, y=None, sample_weight=None):
     @classproperty
     def tokenizer_cls(cls):
         return AlbertTokenizer
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
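With the `presets` property in place, the preprocessor can be built straight from a preset name. A usage sketch, mirroring the calls exercised by the preset tests added in this commit (the `sequence_length=128` override is illustrative):

```python
import keras_nlp

# Build tokenizer + packing logic from a preset name; mirrors the preset
# smoke test in this commit. sequence_length is an illustrative override.
preprocessor = keras_nlp.models.AlbertPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=128,
)

# Returns a dict with "token_ids", "segment_ids" and "padding_mask".
features = preprocessor("The quick brown fox.")
print(features["token_ids"])
```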
keras_nlp/models/albert/albert_presets.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ALBERT model preset configurations."""
+
+
+backbone_presets = {
+    "albert_base_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 12,
+            "num_heads": 12,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 768,
+            "intermediate_dim": 3072,
+            "dropout": 0.0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Base size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/model.h5",
+        "weights_hash": "b83ccf3418dd84adc569324183176813",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+    "albert_large_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 24,
+            "num_heads": 16,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 1024,
+            "intermediate_dim": 4096,
+            "dropout": 0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Large size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/model.h5",
+        "weights_hash": "c7754804efb245f06dd6e7ced32e082c",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+    "albert_extra_large_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 24,
+            "num_heads": 16,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 2048,
+            "intermediate_dim": 8192,
+            "dropout": 0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Extra Large size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/model.h5",
+        "weights_hash": "713209be8aadfa614fd79f18c9aeb16d",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+    "albert_extra_extra_large_en_uncased": {
+        "config": {
+            "vocabulary_size": 30000,
+            "num_layers": 12,
+            "num_heads": 64,
+            "num_groups": 1,
+            "num_inner_repetitions": 1,
+            "embedding_dim": 128,
+            "hidden_dim": 4096,
+            "intermediate_dim": 16384,
+            "dropout": 0,
+            "max_sequence_length": 512,
+            "num_segments": 2,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Extra Extra Large size of ALBERT where all input is lowercased. "
+            "Trained on English Wikipedia + BooksCorpus."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/model.h5",
+        "weights_hash": "a835177b692fb6a82139f94c66db2f22",
+        "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/vocab.spm",
+        "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
+    },
+}
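Each entry pairs a constructor `config` with download URLs and md5 hashes for the weights and the SentencePiece proto. As a hedged sketch of how this metadata gets consumed — it mirrors the tokenizer's `from_preset` added later in this commit, and the `backbone_from_preset` name here is illustrative, not the library's API:

```python
import os

from tensorflow import keras

from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_presets import backbone_presets


def backbone_from_preset(preset, load_weights=True):
    """Illustrative sketch: build an ALBERT backbone from preset metadata."""
    metadata = backbone_presets[preset]
    # The "config" keys map one-to-one onto the backbone constructor args.
    model = AlbertBackbone(**metadata["config"])
    if load_weights:
        # Same download-and-verify pattern as the tokenizer's from_preset.
        weights = keras.utils.get_file(
            "model.h5",
            metadata["weights_url"],
            cache_subdir=os.path.join("models", preset),
            file_hash=metadata["weights_hash"],
        )
        model.load_weights(weights)
    return model
```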
keras_nlp/models/albert/albert_presets_test.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for loading pretrained model presets."""
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+
+from keras_nlp.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
+from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
+
+
+@pytest.mark.large
+class AlbertPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
+    """
+    A smoke test for ALBERT presets we run continuously.
+
+    This only tests the smallest weights we have available. Run with:
+    `pytest keras_nlp/models/albert/albert_presets_test.py --run_large`
+    """
+
+    def test_tokenizer_output(self):
+        tokenizer = AlbertTokenizer.from_preset(
+            "albert_base_en_uncased",
+        )
+        outputs = tokenizer("The quick brown fox.")
+        expected_outputs = [13, 1, 438, 2231, 886, 2385, 9]
+        self.assertAllEqual(outputs, expected_outputs)
+
+    def test_preprocessor_output(self):
+        preprocessor = AlbertPreprocessor.from_preset(
+            "albert_base_en_uncased",
+            sequence_length=4,
+        )
+        outputs = preprocessor("The quick brown fox.")["token_ids"]
+        expected_outputs = [2, 13, 1, 3]
+        self.assertAllEqual(outputs, expected_outputs)
+
+    @parameterized.named_parameters(
+        ("preset_weights", True), ("random_weights", False)
+    )
+    def test_backbone_output(self, load_weights):
+        input_data = {
+            "token_ids": tf.constant([[2, 13, 1, 3]]),
+            "segment_ids": tf.constant([[0, 0, 0, 0]]),
+            "padding_mask": tf.constant([[1, 1, 1, 1]]),
+        }
+        model = AlbertBackbone.from_preset(
+            "albert_base_en_uncased", load_weights=load_weights
+        )
+        outputs = model(input_data)
+        if load_weights:
+            outputs = outputs["sequence_output"][0, 0, :5]
+            expected = [1.830863, 1.698645, -1.819195, -0.53382, -0.38114]
+            self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01)
+
+    @parameterized.named_parameters(
+        ("albert_tokenizer", AlbertTokenizer),
+        ("albert_preprocessor", AlbertPreprocessor),
+        ("albert", AlbertBackbone),
+    )
+    def test_preset_docstring(self, cls):
+        """Check we did our docstring formatting correctly."""
+        for name in cls.presets:
+            self.assertRegex(cls.from_preset.__doc__, name)
+
+    @parameterized.named_parameters(
+        ("albert_tokenizer", AlbertTokenizer),
+        ("albert_preprocessor", AlbertPreprocessor),
+        ("albert", AlbertBackbone),
+    )
+    def test_unknown_preset_error(self, cls):
+        # Not a preset name
+        with self.assertRaises(ValueError):
+            cls.from_preset("albert_base_en_uncased_clowntown")
+
+
+@pytest.mark.extra_large
+class AlbertPresetFullTest(tf.test.TestCase, parameterized.TestCase):
+    """
+    Test the full enumeration of our presets.
+
+    This tests every ALBERT preset and is only run manually.
+    Run with:
+    `pytest keras_nlp/models/albert/albert_presets_test.py --run_extra_large`
+    """
+
+    @parameterized.named_parameters(
+        ("preset_weights", True), ("random_weights", False)
+    )
+    def test_load_albert(self, load_weights):
+        for preset in AlbertBackbone.presets:
+            model = AlbertBackbone.from_preset(
+                preset, load_weights=load_weights
+            )
+            input_data = {
+                "token_ids": tf.random.uniform(
+                    shape=(1, 512), dtype=tf.int64, maxval=model.vocabulary_size
+                ),
+                "segment_ids": tf.constant(
+                    [0] * 200 + [1] * 312, shape=(1, 512)
+                ),
+                "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+            }
+            model(input_data)
+
+    def test_load_tokenizers(self):
+        for preset in AlbertTokenizer.presets:
+            tokenizer = AlbertTokenizer.from_preset(preset)
+            tokenizer("The quick brown fox.")
+
+    def test_load_preprocessors(self):
+        for preset in AlbertPreprocessor.presets:
+            preprocessor = AlbertPreprocessor.from_preset(preset)
+            preprocessor("The quick brown fox.")

keras_nlp/models/albert/albert_tokenizer.py

Lines changed: 47 additions & 2 deletions
@@ -14,11 +14,15 @@
 
 """ALBERT tokenizer."""
 
+import copy
+import os
 
 from tensorflow import keras
 
+from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
 from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.python_utils import format_docstring
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")

@@ -84,12 +88,53 @@ def __init__(self, proto, **kwargs):
 
     @classproperty
     def presets(cls):
-        return {}
+        return copy.deepcopy(backbone_presets)
 
     @classmethod
+    @format_docstring(names=", ".join(backbone_presets))
     def from_preset(
         cls,
         preset,
         **kwargs,
     ):
-        raise NotImplementedError
+        """Instantiate an ALBERT tokenizer from preset vocabulary.
+
+        Args:
+            preset: string. Must be one of {{names}}.
+
+        Examples:
+        ```python
+        # Load a preset tokenizer.
+        tokenizer = keras_nlp.models.AlbertTokenizer.from_preset(
+            "albert_base_en_uncased",
+        )
+
+        # Tokenize some input.
+        tokenizer("The quick brown fox tripped.")
+
+        # Detokenize some input.
+        tokenizer.detokenize([5, 6, 7, 8, 9])
+        ```
+        """
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+        metadata = cls.presets[preset]
+
+        spm_proto = keras.utils.get_file(
+            "vocab.spm",
+            metadata["spm_proto_url"],
+            cache_subdir=os.path.join("models", preset),
+            file_hash=metadata["spm_proto_hash"],
+        )
+
+        config = metadata["preprocessor_config"]
+        config.update(
+            {
+                "proto": spm_proto,
+            },
+        )
+
+        return cls.from_config({**config, **kwargs})
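The `{{names}}` placeholder in that docstring is filled in by the `format_docstring` decorator, which `test_preset_docstring` above relies on. The decorator's implementation is not part of this diff; a minimal sketch of the idea (the real `keras_nlp.utils.python_utils` version may differ):

```python
def format_docstring(**replacements):
    """Sketch: substitute `{{key}}` placeholders in an object's docstring."""

    def decorate(obj):
        doc = obj.__doc__
        for key, value in replacements.items():
            # Replace each `{{key}}` marker with its rendered value.
            doc = doc.replace("{{" + key + "}}", str(value))
        obj.__doc__ = doc
        return obj

    return decorate


@format_docstring(names="albert_base_en_uncased")
def f():
    """Must be one of {{names}}."""


print(f.__doc__)  # -> Must be one of albert_base_en_uncased.
```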
