8 changes: 8 additions & 0 deletions keras_nlp/models/albert/albert_backbone.py
@@ -14,12 +14,16 @@

"""ALBERT backbone model."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.models.backbone import Backbone
from keras_nlp.utils.python_utils import classproperty


def albert_kernel_initializer(stddev=0.02):
@@ -264,3 +268,7 @@ def get_config(self):
"name": self.name,
"trainable": self.trainable,
}

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
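
With `presets` registered on the backbone, the inherited `Backbone.from_preset` constructor can resolve these configurations by name. A minimal usage sketch, assuming the standard KerasNLP preset flow and network access for the weight download:

```python
import keras_nlp

# Load the pretrained base backbone by preset name. Weights download on
# first use; pass load_weights=False for a randomly initialized model.
backbone = keras_nlp.models.AlbertBackbone.from_preset("albert_base_en_uncased")

# The full preset catalog is exposed as a class-level mapping.
print(list(keras_nlp.models.AlbertBackbone.presets))
```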
7 changes: 7 additions & 0 deletions keras_nlp/models/albert/albert_preprocessor.py
@@ -13,9 +13,12 @@
# limitations under the License.
"""ALBERT preprocessor layer."""

import copy

from tensorflow import keras

from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
@@ -179,3 +182,7 @@ def call(self, x, y=None, sample_weight=None):
@classproperty
def tokenizer_cls(cls):
return AlbertTokenizer

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
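
The same catalog now powers `AlbertPreprocessor.from_preset`. A short sketch of the intended call pattern (assuming the preset files are reachable):

```python
import keras_nlp

# Build a preprocessor (tokenizer included) from the shared preset catalog.
preprocessor = keras_nlp.models.AlbertPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=128,
)

# Returns the packed dict the backbone expects:
# "token_ids", "segment_ids", and "padding_mask".
features = preprocessor("The quick brown fox.")
```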
114 changes: 114 additions & 0 deletions keras_nlp/models/albert/albert_presets.py
@@ -0,0 +1,114 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT model preset configurations."""


backbone_presets = {
"albert_base_en_uncased": {
"config": {
"vocabulary_size": 30000,
"num_layers": 12,
"num_heads": 12,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 768,
"intermediate_dim": 3072,
"dropout": 0.0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"description": (
"Base size of ALBERT where all input is lowercased. "
"Trained on English Wikipedia + BooksCorpus."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/model.h5",
"weights_hash": "b83ccf3418dd84adc569324183176813",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
},
"albert_large_en_uncased": {
"config": {
"vocabulary_size": 30000,
"num_layers": 24,
"num_heads": 16,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 1024,
"intermediate_dim": 4096,
"dropout": 0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"description": (
"Large size of ALBERT where all input is lowercased. "
"Trained on English Wikipedia + BooksCorpus."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/model.h5",
"weights_hash": "c7754804efb245f06dd6e7ced32e082c",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
},
"albert_extra_large_en_uncased": {
"config": {
"vocabulary_size": 30000,
"num_layers": 24,
"num_heads": 16,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 2048,
"intermediate_dim": 8192,
"dropout": 0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"description": (
"Extra Large size of ALBERT where all input is lowercased. "
"Trained on English Wikipedia + BooksCorpus."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/model.h5",
"weights_hash": "713209be8aadfa614fd79f18c9aeb16d",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
},
"albert_extra_extra_large_en_uncased": {
"config": {
"vocabulary_size": 30000,
"num_layers": 12,
"num_heads": 64,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 4096,
"intermediate_dim": 16384,
"dropout": 0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"description": (
"Extra Large size of ALBERT where all input is lowercased. "
"Trained on English Wikipedia + BooksCorpus."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/model.h5",
"weights_hash": "a835177b692fb6a82139f94c66db2f22",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
},
}
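
Two things are worth noting about this catalog. First, each `"config"` dict is keyed to match the `AlbertBackbone` constructor arguments, so `from_preset` can splat it directly; a hedged sketch of that mapping:

```python
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_presets import backbone_presets

# Sketch: build an untrained model straight from a preset config. This
# assumes config keys map one-to-one to constructor arguments, which is
# the convention from_preset relies on.
config = backbone_presets["albert_base_en_uncased"]["config"]
model = AlbertBackbone(**config)
```

Second, the `presets` accessors return `copy.deepcopy(backbone_presets)` through `classproperty` from `keras_nlp.utils.python_utils`. A minimal sketch of how such a descriptor can work (the real helper may differ in detail):

```python
import copy

class classproperty(property):
    """A property readable on the class itself, not just on instances."""

    def __get__(self, obj, owner_cls):
        # Bind the getter to the class so `Cls.presets` needs no instance.
        return self.fget(owner_cls)

CATALOG = {"albert_base_en_uncased": {"config": {"num_layers": 12}}}

class Demo:
    @classproperty
    def presets(cls):
        # Deep-copy so callers cannot mutate the module-level catalog.
        return copy.deepcopy(CATALOG)

assert Demo.presets == CATALOG
Demo.presets["albert_base_en_uncased"] = None  # mutates only the copy
assert Demo.presets == CATALOG
```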
127 changes: 127 additions & 0 deletions keras_nlp/models/albert/albert_presets_test.py
@@ -0,0 +1,127 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for loading pretrained model presets."""

import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer


@pytest.mark.large
class AlbertPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
"""
A smoke test for ALBERT presets we run continuously.

This only tests the smallest weights we have available. Run with:
`pytest keras_nlp/models/albert/albert_presets_test.py --run_large`
"""

def test_tokenizer_output(self):
tokenizer = AlbertTokenizer.from_preset(
"albert_base_en_uncased",
)
outputs = tokenizer("The quick brown fox.")
expected_outputs = [13, 1, 438, 2231, 886, 2385, 9]
self.assertAllEqual(outputs, expected_outputs)

def test_preprocessor_output(self):
preprocessor = AlbertPreprocessor.from_preset(
"albert_base_en_uncased",
sequence_length=4,
)
outputs = preprocessor("The quick brown fox.")["token_ids"]
expected_outputs = [2, 13, 1, 3]
self.assertAllEqual(outputs, expected_outputs)

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_backbone_output(self, load_weights):
input_data = {
"token_ids": tf.constant([[2, 13, 1, 3]]),
"segment_ids": tf.constant([[0, 0, 0, 0]]),
"padding_mask": tf.constant([[1, 1, 1, 1]]),
}
model = AlbertBackbone.from_preset(
"albert_base_en_uncased", load_weights=load_weights
)
outputs = model(input_data)
if load_weights:
outputs = outputs["sequence_output"][0, 0, :5]
expected = [1.830863, 1.698645, -1.819195, -0.53382, -0.38114]
self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01)

@parameterized.named_parameters(
("albert_tokenizer", AlbertTokenizer),
("albert_preprocessor", AlbertPreprocessor),
("albert", AlbertBackbone),
)
def test_preset_docstring(self, cls):
"""Check we did our docstring formatting correctly."""
for name in cls.presets:
self.assertRegex(cls.from_preset.__doc__, name)

@parameterized.named_parameters(
("albert_tokenizer", AlbertTokenizer),
("albert_preprocessor", AlbertPreprocessor),
("albert", AlbertBackbone),
)
def test_unknown_preset_error(self, cls):
# Not a preset name
with self.assertRaises(ValueError):
cls.from_preset("albert_base_en_uncased_clowntown")


@pytest.mark.extra_large
class AlbertPresetFullTest(tf.test.TestCase, parameterized.TestCase):
"""
Test the full enumeration of our presets.

This tests every ALBERT preset and is only run manually.
Run with:
`pytest keras_nlp/models/albert/albert_presets_test.py --run_extra_large`
"""

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_load_albert(self, load_weights):
for preset in AlbertBackbone.presets:
model = AlbertBackbone.from_preset(
preset, load_weights=load_weights
)
input_data = {
"token_ids": tf.random.uniform(
shape=(1, 512), dtype=tf.int64, maxval=model.vocabulary_size
),
"segment_ids": tf.constant(
[0] * 200 + [1] * 312, shape=(1, 512)
),
"padding_mask": tf.constant([1] * 512, shape=(1, 512)),
}
model(input_data)

def test_load_tokenizers(self):
for preset in AlbertTokenizer.presets:
tokenizer = AlbertTokenizer.from_preset(preset)
tokenizer("The quick brown fox.")

def test_load_preprocessors(self):
for preset in AlbertPreprocessor.presets:
preprocessor = AlbertPreprocessor.from_preset(preset)
preprocessor("The quick brown fox.")
49 changes: 47 additions & 2 deletions keras_nlp/models/albert/albert_tokenizer.py
@@ -14,11 +14,15 @@

"""ALBERT tokenizer."""

import copy
import os

from tensorflow import keras

from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring


@keras.utils.register_keras_serializable(package="keras_nlp")
@@ -84,12 +88,53 @@ def __init__(self, proto, **kwargs):

@classproperty
def presets(cls):
- return {}
+ return copy.deepcopy(backbone_presets)

@classmethod
@format_docstring(names=", ".join(backbone_presets))
def from_preset(
cls,
preset,
**kwargs,
):
- raise NotImplementedError
"""Instantiate an ALBERT tokenizer from preset vocabulary.

Args:
preset: string. Must be one of {{names}}.

Examples:
```python
# Load a preset tokenizer.
tokenizer = keras_nlp.models.AlbertTokenizer.from_preset(
"albert_base_en_uncased",
)

# Tokenize some input.
tokenizer("The quick brown fox tripped.")

# Detokenize some input.
tokenizer.detokenize([5, 6, 7, 8, 9])
```
"""
if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)
metadata = cls.presets[preset]

spm_proto = keras.utils.get_file(
"vocab.spm",
metadata["spm_proto_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["spm_proto_hash"],
)

config = metadata["preprocessor_config"]
config.update(
{
"proto": spm_proto,
},
)

return cls.from_config({**config, **kwargs})
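
The `@format_docstring(names=", ".join(backbone_presets))` decorator is what fills the `{{names}}` placeholder in the docstring above, and is what `test_preset_docstring` asserts against. A minimal sketch of that substitution (the real `keras_nlp.utils.python_utils.format_docstring` may differ):

```python
def format_docstring(**replacements):
    """Substitute {{key}} placeholders in the wrapped function's docstring."""

    def decorator(fn):
        doc = fn.__doc__ or ""
        for key, value in replacements.items():
            doc = doc.replace("{{" + key + "}}", str(value))
        fn.__doc__ = doc
        return fn

    return decorator
```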